Compare commits
116 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
549f32a743 | ||
|
|
e3b3452a73 | ||
|
|
62350d975d | ||
|
|
8d84b326b8 | ||
|
|
16079a3cba | ||
|
|
ff330a5487 | ||
|
|
94b3882d30 | ||
|
|
81544ca8b3 | ||
|
|
b7f4dd835e | ||
|
|
7e2380e4ed | ||
|
|
7f3cfd54b0 | ||
|
|
e083f2c629 | ||
|
|
e33858f110 | ||
|
|
da01a33152 | ||
|
|
8ca920a114 | ||
|
|
5f3d449a66 | ||
|
|
13735e7dfb | ||
|
|
a38d5c3a25 | ||
|
|
5bae637c67 | ||
|
|
12e488ba80 | ||
|
|
ad30c63c69 | ||
|
|
a116eff7df | ||
|
|
01bc355dde | ||
|
|
8e05f3c360 | ||
|
|
fde988dd4e | ||
|
|
91401ad14f | ||
|
|
280194647c | ||
|
|
2e0a542f33 | ||
|
|
b988694da7 | ||
|
|
512c4d0f73 | ||
|
|
5525fb1470 | ||
|
|
4db735e026 | ||
|
|
c8c79c39d1 | ||
|
|
bcfb76d8ca | ||
|
|
2d9aaf8fc9 | ||
|
|
8a3905c09a | ||
|
|
54cd8a46fa | ||
|
|
1b83bf261a | ||
|
|
2a7d22dab1 | ||
|
|
f7494b0cfb | ||
|
|
9ca91d59ec | ||
|
|
11feaa6e68 | ||
|
|
18d4b2304e | ||
|
|
2f45e9c33a | ||
|
|
f7df10cb66 | ||
|
|
46e9a2f5b2 | ||
|
|
69b8d2e0a1 | ||
|
|
0ddd2e9fea | ||
|
|
01c95f5bc4 | ||
|
|
e0bf44d82f | ||
|
|
f328e84ea7 | ||
|
|
c81f5015a1 | ||
|
|
e2b086e2f7 | ||
|
|
da632565d5 | ||
|
|
556b667cc0 | ||
|
|
82c9825da8 | ||
|
|
26b30f0dbe | ||
|
|
be3b69c65c | ||
|
|
07cab6949e | ||
|
|
18d58ce124 | ||
|
|
b8f8837a8f | ||
|
|
0c796c8cfc | ||
|
|
b14fbc29b7 | ||
|
|
6e29f97881 | ||
|
|
a164939161 | ||
|
|
09ab11ef01 | ||
|
|
ac34edec7f | ||
|
|
6dd8ffa037 | ||
|
|
eaed3f40a2 | ||
|
|
e48f39375e | ||
|
|
9b7b651ef9 | ||
|
|
b5623cb9c2 | ||
|
|
144d12b463 | ||
|
|
fa452f5518 | ||
|
|
a159d21d45 | ||
|
|
3a00bbf44d | ||
|
|
9f5e94fa8f | ||
|
|
87e1daa733 | ||
|
|
f5900179e0 | ||
|
|
51e162970e | ||
|
|
0b339ad0f6 | ||
|
|
60693d6a29 | ||
|
|
eea53a6e9e | ||
|
|
8a19181a38 | ||
|
|
94d835c7ae | ||
|
|
d9e25ad69f | ||
|
|
75244fbd8b | ||
|
|
5ce84edc3d | ||
|
|
1c683087f4 | ||
|
|
85a3b39cbc | ||
|
|
cc6c24f0c3 | ||
|
|
c733b6419c | ||
|
|
c853c5b60b | ||
|
|
053a08f5b7 | ||
|
|
f7227cd1c1 | ||
|
|
861e245062 | ||
|
|
8f0fc7db56 | ||
|
|
3dd06fa70e | ||
|
|
86a855e7bc | ||
|
|
b3110d4ad8 | ||
|
|
602004ad34 | ||
|
|
a8b4f0bb7e | ||
|
|
24cc8be085 | ||
|
|
a96d7aef8d | ||
|
|
cbe299583b | ||
|
|
68c70a362b | ||
|
|
a78c346371 | ||
|
|
102763b94d | ||
|
|
ad65765ba8 | ||
|
|
d04fd7cb87 | ||
|
|
b398cbb591 | ||
|
|
19b97e985c | ||
|
|
93bf74a320 | ||
|
|
7daae23bbb | ||
|
|
0d0a3f15cc | ||
|
|
04fbb38861 |
3
.gitattributes
vendored
3
.gitattributes
vendored
@@ -1,8 +1,11 @@
|
||||
* text=auto eol=lf
|
||||
|
||||
backend-python/rwkv_pip/** linguist-vendored
|
||||
backend-python/wkv_cuda_utils/** linguist-vendored
|
||||
backend-python/get-pip.py linguist-vendored
|
||||
backend-python/convert_model.py linguist-vendored
|
||||
backend-python/convert_safetensors.py linguist-vendored
|
||||
backend-python/convert_pytorch_to_ggml.py linguist-vendored
|
||||
backend-python/utils/midi.py linguist-vendored
|
||||
build/** linguist-vendored
|
||||
finetune/lora/** linguist-vendored
|
||||
|
||||
9
.github/dependabot.yml
vendored
Normal file
9
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
commit-message:
|
||||
prefix: "chore"
|
||||
include: "scope"
|
||||
53
.github/workflows/release.yml
vendored
53
.github/workflows/release.yml
vendored
@@ -48,16 +48,13 @@ jobs:
|
||||
id: cp310
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
override: true
|
||||
target: wasm32-unknown-unknown
|
||||
- uses: crazy-max/ghaction-chocolatey@v2
|
||||
with:
|
||||
args: install upx
|
||||
- run: |
|
||||
Start-BitsTransfer https://github.com/josStorer/LibreHardwareMonitor.Console/releases/download/v0.1.0/LibreHardwareMonitor.Console.zip ./LibreHardwareMonitor.Console.zip
|
||||
Start-BitsTransfer https://github.com/josStorer/ai00_rwkv_server/releases/latest/download/webgpu_server_windows_x86_64.exe ./backend-rust/webgpu_server.exe
|
||||
Start-BitsTransfer https://github.com/josStorer/web-rwkv-converter/releases/latest/download/web-rwkv-converter_windows_x86_64.exe ./backend-rust/web-rwkv-converter.exe
|
||||
Start-BitsTransfer https://github.com/josStorer/LibreHardwareMonitor.Console/releases/latest/download/LibreHardwareMonitor.Console.zip ./LibreHardwareMonitor.Console.zip
|
||||
Expand-Archive ./LibreHardwareMonitor.Console.zip -DestinationPath ./components/LibreHardwareMonitor.Console
|
||||
Start-BitsTransfer https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip ./python-3.10.11-embed-amd64.zip
|
||||
Expand-Archive ./python-3.10.11-embed-amd64.zip -DestinationPath ./py310
|
||||
@@ -67,13 +64,11 @@ jobs:
|
||||
Copy-Item -Path "${{ steps.cp310.outputs.python-path }}/../include" -Destination "py310/include" -Recurse
|
||||
Copy-Item -Path "${{ steps.cp310.outputs.python-path }}/../libs" -Destination "py310/libs" -Recurse
|
||||
./py310/python -m pip install cyac==1.9
|
||||
git clone https://github.com/josStorer/ai00_rwkv_server --depth=1
|
||||
cd ai00_rwkv_server
|
||||
cargo build --release
|
||||
mv ./target/release/ai00_server.exe ../backend-rust/webgpu_server.exe
|
||||
cd ..
|
||||
go install github.com/wailsapp/wails/v2/cmd/wails@latest
|
||||
del ./backend-python/rwkv_pip/cpp/librwkv.dylib
|
||||
del ./backend-python/rwkv_pip/cpp/librwkv.so
|
||||
(Get-Content -Path ./backend-golang/app.go) -replace "//go:custom_build windows ", "" | Set-Content -Path ./backend-golang/app.go
|
||||
(Get-Content -Path ./backend-golang/utils.go) -replace "//go:custom_build windows ", "" | Set-Content -Path ./backend-golang/utils.go
|
||||
make
|
||||
Rename-Item -Path "build/bin/RWKV-Runner.exe" -NewName "RWKV-Runner_windows_x64.exe"
|
||||
|
||||
@@ -89,29 +84,21 @@ jobs:
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.20.5'
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
override: true
|
||||
target: wasm32-unknown-unknown
|
||||
- run: |
|
||||
wget https://github.com/josStorer/ai00_rwkv_server/releases/latest/download/webgpu_server_linux_x86_64 -O ./backend-rust/webgpu_server
|
||||
wget https://github.com/josStorer/web-rwkv-converter/releases/latest/download/web-rwkv-converter_linux_x86_64 -O ./backend-rust/web-rwkv-converter
|
||||
sudo apt-get update
|
||||
sudo apt-get install upx
|
||||
sudo apt-get install build-essential libgtk-3-dev libwebkit2gtk-4.0-dev
|
||||
git clone https://github.com/josStorer/ai00_rwkv_server --depth=1
|
||||
cd ai00_rwkv_server
|
||||
sudo apt-get install libudev-dev
|
||||
sudo apt-get install libasound2-dev
|
||||
rustup target add x86_64-unknown-linux-gnu
|
||||
cargo build --release --target x86_64-unknown-linux-gnu
|
||||
mv ./target/x86_64-unknown-linux-gnu/release/ai00_server ../backend-rust/webgpu_server
|
||||
cd ..
|
||||
sudo apt-get install build-essential libgtk-3-dev libwebkit2gtk-4.0-dev libasound2-dev
|
||||
go install github.com/wailsapp/wails/v2/cmd/wails@latest
|
||||
rm ./backend-python/rwkv_pip/wkv_cuda.pyd
|
||||
rm ./backend-python/rwkv_pip/rwkv5.pyd
|
||||
rm ./backend-python/rwkv_pip/rwkv6.pyd
|
||||
rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd
|
||||
rm ./backend-python/get-pip.py
|
||||
rm ./backend-python/rwkv_pip/cpp/librwkv.dylib
|
||||
rm ./backend-python/rwkv_pip/cpp/rwkv.dll
|
||||
rm ./backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd
|
||||
make
|
||||
mv build/bin/RWKV-Runner build/bin/RWKV-Runner_linux_x64
|
||||
|
||||
@@ -127,24 +114,18 @@ jobs:
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '1.20.5'
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
override: true
|
||||
target: wasm32-unknown-unknown
|
||||
- run: |
|
||||
git clone https://github.com/josStorer/ai00_rwkv_server --depth=1
|
||||
cd ai00_rwkv_server
|
||||
rustup target add aarch64-apple-darwin
|
||||
cargo build --release --target aarch64-apple-darwin
|
||||
mv ./target/aarch64-apple-darwin/release/ai00_server ../backend-rust/webgpu_server
|
||||
cd ..
|
||||
wget https://github.com/josStorer/ai00_rwkv_server/releases/latest/download/webgpu_server_darwin_aarch64 -O ./backend-rust/webgpu_server
|
||||
wget https://github.com/josStorer/web-rwkv-converter/releases/latest/download/web-rwkv-converter_darwin_aarch64 -O ./backend-rust/web-rwkv-converter
|
||||
go install github.com/wailsapp/wails/v2/cmd/wails@latest
|
||||
rm ./backend-python/rwkv_pip/wkv_cuda.pyd
|
||||
rm ./backend-python/rwkv_pip/rwkv5.pyd
|
||||
rm ./backend-python/rwkv_pip/rwkv6.pyd
|
||||
rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd
|
||||
rm ./backend-python/get-pip.py
|
||||
rm ./backend-python/rwkv_pip/cpp/rwkv.dll
|
||||
rm ./backend-python/rwkv_pip/cpp/librwkv.so
|
||||
rm ./backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd
|
||||
make
|
||||
cp build/darwin/Readme_Install.txt build/bin/Readme_Install.txt
|
||||
cp build/bin/RWKV-Runner.app/Contents/MacOS/RWKV-Runner build/bin/RWKV-Runner_darwin_universal
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -8,6 +8,7 @@ __pycache__
|
||||
*.st
|
||||
*.safetensors
|
||||
*.bin
|
||||
*.mid
|
||||
/config.json
|
||||
/cache.json
|
||||
/presets.json
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
## Changes
|
||||
|
||||
- MIDI Input Audio Tracks (Experimental, playing tracks is not supported yet, please save to generation area to preview)
|
||||
- fix autoPlayed midi cannot be stopped
|
||||
- try to use local soundfont by default
|
||||
- improve details
|
||||
- abc music inference support
|
||||
- basic abc frontend support
|
||||
- fix finetune errorsMap ($modelInfo)
|
||||
|
||||
## Install
|
||||
|
||||
- Windows: https://github.com/josStorer/RWKV-Runner/blob/master/build/windows/Readme_Install.txt
|
||||
- MacOS: https://github.com/josStorer/RWKV-Runner/blob/master/build/darwin/Readme_Install.txt
|
||||
- Linux: https://github.com/josStorer/RWKV-Runner/blob/master/build/linux/Readme_Install.txt
|
||||
- Server-Deploy-Examples: https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples
|
||||
|
||||
#### MIDI Input Audio Tracks
|
||||
|
||||

|
||||
- Simple Deploy Example: https://github.com/josStorer/RWKV-Runner/blob/master/README.md#simple-deploy-example
|
||||
- Server Deploy Examples: https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples
|
||||
|
||||
6
Makefile
6
Makefile
@@ -8,15 +8,15 @@ endif
|
||||
|
||||
build-windows:
|
||||
@echo ---- build for windows
|
||||
wails build -upx -ldflags "-s -w" -platform windows/amd64
|
||||
wails build -upx -ldflags '-s -w -extldflags "-static"' -platform windows/amd64
|
||||
|
||||
build-macos:
|
||||
@echo ---- build for macos
|
||||
wails build -ldflags "-s -w" -platform darwin/universal
|
||||
wails build -ldflags '-s -w' -platform darwin/universal
|
||||
|
||||
build-linux:
|
||||
@echo ---- build for linux
|
||||
wails build -upx -ldflags "-s -w" -platform linux/amd64
|
||||
wails build -upx -ldflags '-s -w' -platform linux/amd64
|
||||
|
||||
build-web:
|
||||
@echo ---- build for web
|
||||
|
||||
144
README.md
144
README.md
@@ -21,7 +21,7 @@ English | [简体中文](README_ZH.md) | [日本語](README_JA.md)
|
||||
[![MacOS][MacOS-image]][MacOS-url]
|
||||
[![Linux][Linux-image]][Linux-url]
|
||||
|
||||
[FAQs](https://github.com/josStorer/RWKV-Runner/wiki/FAQs) | [Preview](#Preview) | [Download][download-url] | [Server-Deploy-Examples](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
|
||||
[FAQs](https://github.com/josStorer/RWKV-Runner/wiki/FAQs) | [Preview](#Preview) | [Download][download-url] | [Simple Deploy Example](#Simple-Deploy-Example) | [Server Deploy Examples](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples) | [MIDI Hardware Input](#MIDI-Input)
|
||||
|
||||
[license-image]: http://img.shields.io/badge/license-MIT-blue.svg
|
||||
|
||||
@@ -47,30 +47,74 @@ English | [简体中文](README_ZH.md) | [日本語](README_JA.md)
|
||||
|
||||
</div>
|
||||
|
||||
#### Tip: You can deploy [backend-python](./backend-python/) on a server and use this program as a client only. Fill in your server address in the Settings `API URL`.
|
||||
## Tips
|
||||
|
||||
#### Default configs has enabled custom CUDA kernel acceleration, which is much faster and consumes much less VRAM. If you encounter possible compatibility issues (output garbled), go to the Configs page and turn off `Use Custom CUDA kernel to Accelerate`, or try to upgrade your gpu driver.
|
||||
- You can deploy [backend-python](./backend-python/) on a server and use this program as a client only. Fill in
|
||||
your server address in the Settings `API URL`.
|
||||
|
||||
#### If Windows Defender claims this is a virus, you can try downloading [v1.3.7_win.zip](https://github.com/josStorer/RWKV-Runner/releases/download/v1.3.7/RWKV-Runner_win.zip) and letting it update automatically to the latest version, or add it to the trusted list (`Windows Security` -> `Virus & threat protection` -> `Manage settings` -> `Exclusions` -> `Add or remove exclusions` -> `Add an exclusion` -> `Folder` -> `RWKV-Runner`).
|
||||
- If you are deploying and providing public services, please limit the request size through API gateway to prevent
|
||||
excessive resource usage caused by submitting overly long prompts. Additionally, please restrict the upper limit of
|
||||
requests' max_tokens based on your actual
|
||||
situation: https://github.com/josStorer/RWKV-Runner/blob/master/backend-python/utils/rwkv.py#L567, the default is set
|
||||
as le=102400, which may result in significant resource consumption for individual responses in extreme cases.
|
||||
|
||||
#### For different tasks, adjusting API parameters can achieve better results. For example, for translation tasks, you can try setting Temperature to 1 and Top_P to 0.3.
|
||||
- Default configs has enabled custom CUDA kernel acceleration, which is much faster and consumes much less VRAM. If you
|
||||
encounter possible compatibility issues (output garbled), go to the Configs page and turn
|
||||
off `Use Custom CUDA kernel to Accelerate`, or try to upgrade your gpu driver.
|
||||
|
||||
- If Windows Defender claims this is a virus, you can try
|
||||
downloading [v1.3.7_win.zip](https://github.com/josStorer/RWKV-Runner/releases/download/v1.3.7/RWKV-Runner_win.zip)
|
||||
and letting it update automatically to the latest version, or add it to the trusted
|
||||
list (`Windows Security` -> `Virus & threat protection` -> `Manage settings` -> `Exclusions` -> `Add or remove exclusions` -> `Add an exclusion` -> `Folder` -> `RWKV-Runner`).
|
||||
|
||||
- For different tasks, adjusting API parameters can achieve better results. For example, for translation tasks, you can
|
||||
try setting Temperature to 1 and Top_P to 0.3.
|
||||
|
||||
## Features
|
||||
|
||||
- RWKV model management and one-click startup
|
||||
- Fully compatible with the OpenAI API, making every ChatGPT client an RWKV client. After starting the model,
|
||||
- RWKV model management and one-click startup.
|
||||
- Front-end and back-end separation, if you don't want to use the client, also allows for separately deploying the
|
||||
front-end service, or the back-end inference service, or the back-end inference service with a WebUI.
|
||||
[Simple Deploy Example](#Simple-Deploy-Example) | [Server Deploy Examples](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
|
||||
- Compatible with the OpenAI API, making every ChatGPT client an RWKV client. After starting the model,
|
||||
open http://127.0.0.1:8000/docs to view more details.
|
||||
- Automatic dependency installation, requiring only a lightweight executable program
|
||||
- Configs with 2G to 32G VRAM are included, works well on almost all computers
|
||||
- User-friendly chat and completion interaction interface included
|
||||
- Easy-to-understand and operate parameter configuration
|
||||
- Built-in model conversion tool
|
||||
- Built-in download management and remote model inspection
|
||||
- Built-in one-click LoRA Finetune
|
||||
- Can also be used as an OpenAI ChatGPT and GPT-Playground client
|
||||
- Multilingual localization
|
||||
- Theme switching
|
||||
- Automatic updates
|
||||
- Automatic dependency installation, requiring only a lightweight executable program.
|
||||
- Pre-set multi-level VRAM configs, works well on almost all computers. In Configs page, switch Strategy to WebGPU, it
|
||||
can also run on AMD, Intel, and other graphics cards.
|
||||
- User-friendly chat, completion, and composition interaction interface included. Also supports chat presets, attachment
|
||||
uploads, MIDI hardware input, and track editing.
|
||||
[Preview](#Preview) | [MIDI Hardware Input](#MIDI-Input)
|
||||
- Built-in WebUI option, one-click start of Web service, sharing your hardware resources.
|
||||
- Easy-to-understand and operate parameter configuration, along with various operation guidance prompts.
|
||||
- Built-in model conversion tool.
|
||||
- Built-in download management and remote model inspection.
|
||||
- Built-in one-click LoRA Finetune. (Windows Only)
|
||||
- Can also be used as an OpenAI ChatGPT and GPT-Playground client. (Fill in the API URL and API Key in Settings page)
|
||||
- Multilingual localization.
|
||||
- Theme switching.
|
||||
- Automatic updates.
|
||||
|
||||
## Simple Deploy Example
|
||||
|
||||
```bash
|
||||
git clone https://github.com/josStorer/RWKV-Runner
|
||||
|
||||
# Then
|
||||
cd RWKV-Runner
|
||||
python ./backend-python/main.py #The backend inference service has been started, request /switch-model API to load the model, refer to the API documentation: http://127.0.0.1:8000/docs
|
||||
|
||||
# Or
|
||||
cd RWKV-Runner/frontend
|
||||
npm ci
|
||||
npm run build #Compile the frontend
|
||||
cd ..
|
||||
python ./backend-python/webui_server.py #Start the frontend service separately
|
||||
# Or
|
||||
python ./backend-python/main.py --webui #Start the frontend and backend service at the same time
|
||||
|
||||
# Help Info
|
||||
python ./backend-python/main.py -h
|
||||
```
|
||||
|
||||
## API Concurrency Stress Testing
|
||||
|
||||
@@ -133,40 +177,98 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
|
||||
print(f"{embeddings_cos_sim[i]:.10f} - {values[i]}")
|
||||
```
|
||||
|
||||
## MIDI Input
|
||||
|
||||
Tip: You can download https://github.com/josStorer/sgm_plus and unzip it to the program's `assets/sound-font` directory
|
||||
to use it as an offline sound source. Please note that if you are compiling the program from source code, do not place
|
||||
it in the source code directory.
|
||||
|
||||
If you don't have a MIDI keyboard, you can use virtual MIDI input software like `Virtual Midi Controller 3 LE`, along
|
||||
with [loopMIDI](https://www.tobias-erichsen.de/wp-content/uploads/2020/01/loopMIDISetup_1_0_16_27.zip), to use a regular
|
||||
computer keyboard as MIDI input.
|
||||
|
||||
### USB MIDI Connection
|
||||
|
||||
- USB MIDI devices are plug-and-play, and you can select your input device in the Composition page
|
||||
- 
|
||||
|
||||
### Mac MIDI Bluetooth Connection
|
||||
|
||||
- For Mac users who want to use Bluetooth input,
|
||||
please install [Bluetooth MIDI Connect](https://apps.apple.com/us/app/bluetooth-midi-connect/id1108321791), then click
|
||||
the tray icon to connect after launching,
|
||||
afterwards, you can select your input device in the Composition page.
|
||||
- 
|
||||
|
||||
### Windows MIDI Bluetooth Connection
|
||||
|
||||
- Windows seems to have implemented Bluetooth MIDI support only for UWP (Universal Windows Platform) apps. Therefore, it
|
||||
requires multiple steps to establish a connection. We need to create a local virtual MIDI device and then launch a UWP
|
||||
application. Through this UWP application, we will redirect Bluetooth MIDI input to the virtual MIDI device, and then
|
||||
this software will listen to the input from the virtual MIDI device.
|
||||
- So, first, you need to
|
||||
download [loopMIDI](https://www.tobias-erichsen.de/wp-content/uploads/2020/01/loopMIDISetup_1_0_16_27.zip)
|
||||
to create a virtual MIDI device. Click the plus sign in the bottom left corner to create the device.
|
||||
- 
|
||||
- Next, you need to download [Bluetooth LE Explorer](https://apps.microsoft.com/detail/9N0ZTKF1QD98) to discover and
|
||||
connect to Bluetooth MIDI devices. Click "Start" to search for devices, and then click "Pair" to bind the MIDI device.
|
||||
- 
|
||||
- Finally, you need to install [MIDIberry](https://apps.microsoft.com/detail/9N39720H2M05),
|
||||
This UWP application can redirect Bluetooth MIDI input to the virtual MIDI device. After launching it, double-click
|
||||
your actual Bluetooth MIDI device name in the input field, and in the output field, double-click the virtual MIDI
|
||||
device name we created earlier.
|
||||
- 
|
||||
- Now, you can select the virtual MIDI device as the input in the Composition page. Bluetooth LE Explorer no longer
|
||||
needs to run, and you can also close the loopMIDI window, it will run automatically in the background. Just keep
|
||||
MIDIberry open.
|
||||
- 
|
||||
|
||||
## Related Repositories:
|
||||
|
||||
- RWKV-5-World: https://huggingface.co/BlinkDL/rwkv-5-world/tree/main
|
||||
- RWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main
|
||||
- RWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main
|
||||
- ChatRWKV: https://github.com/BlinkDL/ChatRWKV
|
||||
- RWKV-LM: https://github.com/BlinkDL/RWKV-LM
|
||||
- RWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA
|
||||
- MIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer
|
||||
- ai00_rwkv_server: https://github.com/cgisky1980/ai00_rwkv_server
|
||||
- rwkv.cpp: https://github.com/saharNooby/rwkv.cpp
|
||||
- web-rwkv-py: https://github.com/cryscan/web-rwkv-py
|
||||
|
||||
## Preview
|
||||
|
||||
### Homepage
|
||||
|
||||

|
||||

|
||||
|
||||
### Chat
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### Completion
|
||||
|
||||

|
||||
|
||||
### Composition
|
||||
|
||||
Tip: You can download https://github.com/josStorer/sgm_plus and unzip it to the program's `assets/sound-font` directory
|
||||
to use it as an offline sound source. Please note that if you are compiling the program from source code, do not place
|
||||
it in the source code directory.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### Configuration
|
||||
|
||||

|
||||

|
||||
|
||||
### Model Management
|
||||
|
||||

|
||||

|
||||
|
||||
### Download Management
|
||||
|
||||
|
||||
128
README_JA.md
128
README_JA.md
@@ -21,7 +21,7 @@
|
||||
[![MacOS][MacOS-image]][MacOS-url]
|
||||
[![Linux][Linux-image]][Linux-url]
|
||||
|
||||
[FAQs](https://github.com/josStorer/RWKV-Runner/wiki/FAQs) | [プレビュー](#Preview) | [ダウンロード][download-url] | [サーバーデプロイ例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
|
||||
[FAQs](https://github.com/josStorer/RWKV-Runner/wiki/FAQs) | [プレビュー](#Preview) | [ダウンロード][download-url] | [シンプルなデプロイの例](#Simple-Deploy-Example) | [サーバーデプロイ例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples) | [MIDIハードウェア入力](#MIDI-Input)
|
||||
|
||||
[license-image]: http://img.shields.io/badge/license-MIT-blue.svg
|
||||
|
||||
@@ -47,31 +47,71 @@
|
||||
|
||||
</div>
|
||||
|
||||
#### ヒント:サーバーに[backend-python](./backend-python/)をデプロイし、このプログラムをクライアントとして使用することができます。設定された`API URL`にサーバーアドレスを入力してください。
|
||||
## ヒント
|
||||
|
||||
#### デフォルトの設定はカスタム CUDA カーネルアクセラレーションを有効にしています。互換性の問題 (文字化けを出力する) が発生する可能性がある場合は、コンフィグページに移動し、`Use Custom CUDA kernel to Accelerate` をオフにしてください、あるいは、GPUドライバーをアップグレードしてみてください。
|
||||
- サーバーに [backend-python](./backend-python/)
|
||||
をデプロイし、このプログラムをクライアントとして使用することができます。設定された`API URL`にサーバーアドレスを入力してください。
|
||||
|
||||
#### Windows Defender がこれをウイルスだと主張する場合は、[v1.3.7_win.zip](https://github.com/josStorer/RWKV-Runner/releases/download/v1.3.7/RWKV-Runner_win.zip) をダウンロードして最新版に自動更新させるか、信頼済みリストに追加してみてください (`Windows Security` -> `Virus & threat protection` -> `Manage settings` -> `Exclusions` -> `Add or remove exclusions` -> `Add an exclusion` -> `Folder` -> `RWKV-Runner`)。
|
||||
- もし、あなたがデプロイし、外部に公開するサービスを提供している場合、APIゲートウェイを使用してリクエストのサイズを制限し、
|
||||
長すぎるプロンプトの提出がリソースを占有しないようにしてください。さらに、実際の状況に応じて、リクエストの max_tokens
|
||||
の上限を制限してください:https://github.com/josStorer/RWKV-Runner/blob/master/backend-python/utils/rwkv.py#L567
|
||||
、デフォルトは le=102400 ですが、極端な場合には単一の応答が大量のリソースを消費する可能性があります。
|
||||
|
||||
#### 異なるタスクについては、API パラメータを調整することで、より良い結果を得ることができます。例えば、翻訳タスクの場合、Temperature を 1 に、Top_P を 0.3 に設定してみてください。
|
||||
- デフォルトの設定はカスタム CUDA カーネルアクセラレーションを有効にしています。互換性の問題 (文字化けを出力する)
|
||||
が発生する可能性がある場合は、コンフィグページに移動し、`Use Custom CUDA kernel to Accelerate`
|
||||
をオフにしてください、あるいは、GPUドライバーをアップグレードしてみてください。
|
||||
|
||||
- Windows Defender
|
||||
がこれをウイルスだと主張する場合は、[v1.3.7_win.zip](https://github.com/josStorer/RWKV-Runner/releases/download/v1.3.7/RWKV-Runner_win.zip)
|
||||
をダウンロードして最新版に自動更新させるか、信頼済みリストに追加してみてください (`Windows Security` -> `Virus & threat protection` -> `Manage settings` -> `Exclusions` -> `Add or remove exclusions` -> `Add an exclusion` -> `Folder` -> `RWKV-Runner`)。
|
||||
|
||||
- 異なるタスクについては、API パラメータを調整することで、より良い結果を得ることができます。例えば、翻訳タスクの場合、Temperature
|
||||
を 1 に、Top_P を 0.3 に設定してみてください。
|
||||
|
||||
## 特徴
|
||||
|
||||
- RWKV モデル管理とワンクリック起動
|
||||
- OpenAI API と完全に互換性があり、すべての ChatGPT クライアントを RWKV クライアントにします。モデル起動後、
|
||||
- フロントエンドとバックエンドの分離は、クライアントを使用しない場合でも、フロントエンドサービス、またはバックエンド推論サービス、またはWebUIを備えたバックエンド推論サービスを個別に展開することを可能にします。
|
||||
[シンプルなデプロイの例](#Simple-Deploy-Example) | [サーバーデプロイ例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
|
||||
- OpenAI API と互換性があり、すべての ChatGPT クライアントを RWKV クライアントにします。モデル起動後、
|
||||
http://127.0.0.1:8000/docs を開いて詳細をご覧ください。
|
||||
- 依存関係の自動インストールにより、軽量な実行プログラムのみを必要とします
|
||||
- 2G から 32G の VRAM のコンフィグが含まれており、ほとんどのコンピュータで動作します
|
||||
- ユーザーフレンドリーなチャットと完成インタラクションインターフェースを搭載
|
||||
- 分かりやすく操作しやすいパラメータ設定
|
||||
- 事前設定された多段階のVRAM設定、ほとんどのコンピュータで動作します。配置ページで、ストラテジーをWebGPUに切り替えると、AMD、インテル、その他のグラフィックカードでも動作します
|
||||
- ユーザーフレンドリーなチャット、完成、および作曲インターフェイスが含まれています。また、チャットプリセット、添付ファイルのアップロード、MIDIハードウェア入力、トラック編集もサポートしています。
|
||||
[プレビュー](#Preview) | [MIDIハードウェア入力](#MIDI-Input)
|
||||
- 内蔵WebUIオプション、Webサービスのワンクリック開始、ハードウェアリソースの共有
|
||||
- 分かりやすく操作しやすいパラメータ設定、各種操作ガイダンスプロンプトとともに
|
||||
- 内蔵モデル変換ツール
|
||||
- ダウンロード管理とリモートモデル検査機能内蔵
|
||||
- 内蔵のLoRA微調整機能を搭載しています
|
||||
- このプログラムは、OpenAI ChatGPTとGPT Playgroundのクライアントとしても使用できます
|
||||
- 内蔵のLoRA微調整機能を搭載しています (Windowsのみ)
|
||||
- このプログラムは、OpenAI ChatGPTとGPT Playgroundのクライアントとしても使用できます(設定ページで `API URL` と `API Key`
|
||||
を入力してください)
|
||||
- 多言語ローカライズ
|
||||
- テーマ切り替え
|
||||
- 自動アップデート
|
||||
|
||||
## Simple Deploy Example
|
||||
|
||||
```bash
|
||||
git clone https://github.com/josStorer/RWKV-Runner
|
||||
|
||||
# Then
|
||||
cd RWKV-Runner
|
||||
python ./backend-python/main.py #The backend inference service has been started, request /switch-model API to load the model, refer to the API documentation: http://127.0.0.1:8000/docs
|
||||
|
||||
# Or
|
||||
cd RWKV-Runner/frontend
|
||||
npm ci
|
||||
npm run build #Compile the frontend
|
||||
cd ..
|
||||
python ./backend-python/webui_server.py #Start the frontend service separately
|
||||
# Or
|
||||
python ./backend-python/main.py --webui #Start the frontend and backend service at the same time
|
||||
|
||||
# Help Info
|
||||
python ./backend-python/main.py -h
|
||||
```
|
||||
|
||||
## API 同時実行ストレステスト
|
||||
|
||||
```bash
|
||||
@@ -134,40 +174,98 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
|
||||
print(f"{embeddings_cos_sim[i]:.10f} - {values[i]}")
|
||||
```
|
||||
|
||||
## MIDI Input
|
||||
|
||||
Tip: You can download https://github.com/josStorer/sgm_plus and unzip it to the program's `assets/sound-font` directory
|
||||
to use it as an offline sound source. Please note that if you are compiling the program from source code, do not place
|
||||
it in the source code directory.
|
||||
|
||||
MIDIキーボードをお持ちでない場合、`Virtual Midi Controller 3 LE`
|
||||
などの仮想MIDI入力ソフトウェアを使用することができます。[loopMIDI](https://www.tobias-erichsen.de/wp-content/uploads/2020/01/loopMIDISetup_1_0_16_27.zip)
|
||||
を組み合わせて、通常のコンピュータキーボードをMIDI入力として使用できます。
|
||||
|
||||
### USB MIDI Connection
|
||||
|
||||
- USB MIDI devices are plug-and-play, and you can select your input device in the Composition page
|
||||
- 
|
||||
|
||||
### Mac MIDI Bluetooth Connection
|
||||
|
||||
- For Mac users who want to use Bluetooth input,
|
||||
please install [Bluetooth MIDI Connect](https://apps.apple.com/us/app/bluetooth-midi-connect/id1108321791), then click
|
||||
the tray icon to connect after launching,
|
||||
afterwards, you can select your input device in the Composition page.
|
||||
- 
|
||||
|
||||
### Windows MIDI Bluetooth Connection
|
||||
|
||||
- Windows seems to have implemented Bluetooth MIDI support only for UWP (Universal Windows Platform) apps. Therefore, it
|
||||
requires multiple steps to establish a connection. We need to create a local virtual MIDI device and then launch a UWP
|
||||
application. Through this UWP application, we will redirect Bluetooth MIDI input to the virtual MIDI device, and then
|
||||
this software will listen to the input from the virtual MIDI device.
|
||||
- So, first, you need to
|
||||
download [loopMIDI](https://www.tobias-erichsen.de/wp-content/uploads/2020/01/loopMIDISetup_1_0_16_27.zip)
|
||||
to create a virtual MIDI device. Click the plus sign in the bottom left corner to create the device.
|
||||
- 
|
||||
- Next, you need to download [Bluetooth LE Explorer](https://apps.microsoft.com/detail/9N0ZTKF1QD98) to discover and
|
||||
connect to Bluetooth MIDI devices. Click "Start" to search for devices, and then click "Pair" to bind the MIDI device.
|
||||
- 
|
||||
- Finally, you need to install [MIDIberry](https://apps.microsoft.com/detail/9N39720H2M05),
|
||||
This UWP application can redirect Bluetooth MIDI input to the virtual MIDI device. After launching it, double-click
|
||||
your actual Bluetooth MIDI device name in the input field, and in the output field, double-click the virtual MIDI
|
||||
device name we created earlier.
|
||||
- 
|
||||
- Now, you can select the virtual MIDI device as the input in the Composition page. Bluetooth LE Explorer no longer
|
||||
needs to run, and you can also close the loopMIDI window, it will run automatically in the background. Just keep
|
||||
MIDIberry open.
|
||||
- 
|
||||
|
||||
## 関連リポジトリ:
|
||||
|
||||
- RWKV-5-World: https://huggingface.co/BlinkDL/rwkv-5-world/tree/main
|
||||
- RWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main
|
||||
- RWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main
|
||||
- ChatRWKV: https://github.com/BlinkDL/ChatRWKV
|
||||
- RWKV-LM: https://github.com/BlinkDL/RWKV-LM
|
||||
- RWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA
|
||||
- MIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer
|
||||
- ai00_rwkv_server: https://github.com/cgisky1980/ai00_rwkv_server
|
||||
- rwkv.cpp: https://github.com/saharNooby/rwkv.cpp
|
||||
- web-rwkv-py: https://github.com/cryscan/web-rwkv-py
|
||||
|
||||
## プレビュー
|
||||
## Preview
|
||||
|
||||
### ホームページ
|
||||
|
||||

|
||||

|
||||
|
||||
### チャット
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### 補完
|
||||
|
||||

|
||||
|
||||
### 作曲
|
||||
|
||||
Tip: You can download https://github.com/josStorer/sgm_plus and unzip it to the program's `assets/sound-font` directory
|
||||
to use it as an offline sound source. Please note that if you are compiling the program from source code, do not place
|
||||
it in the source code directory.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### コンフィグ
|
||||
|
||||

|
||||

|
||||
|
||||
### モデル管理
|
||||
|
||||

|
||||

|
||||
|
||||
### ダウンロード管理
|
||||
|
||||
|
||||
111
README_ZH.md
111
README_ZH.md
@@ -20,7 +20,7 @@ API兼容的接口,这意味着一切ChatGPT客户端都是RWKV客户端。
|
||||
[![MacOS][MacOS-image]][MacOS-url]
|
||||
[![Linux][Linux-image]][Linux-url]
|
||||
|
||||
[视频演示](https://www.bilibili.com/video/BV1hM4y1v76R) | [疑难解答](https://www.bilibili.com/read/cv23921171) | [预览](#Preview) | [下载][download-url] | [懒人包](https://pan.baidu.com/s/1zdzZ_a0uM3gDqi6pXIZVAA?pwd=1111) | [服务器部署示例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
|
||||
[视频演示](https://www.bilibili.com/video/BV1hM4y1v76R) | [疑难解答](https://www.bilibili.com/read/cv23921171) | [预览](#Preview) | [下载][download-url] | [懒人包](https://pan.baidu.com/s/1zdzZ_a0uM3gDqi6pXIZVAA?pwd=1111) | [简明服务部署示例](#Simple-Deploy-Example) | [服务器部署示例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples) | [MIDI硬件输入](#MIDI-Input)
|
||||
|
||||
[license-image]: http://img.shields.io/badge/license-MIT-blue.svg
|
||||
|
||||
@@ -46,30 +46,65 @@ API兼容的接口,这意味着一切ChatGPT客户端都是RWKV客户端。
|
||||
|
||||
</div>
|
||||
|
||||
#### 小贴士:你可以在服务器部署[backend-python](./backend-python/),然后将此程序仅用作客户端,在设置的`API URL`中填入你的服务器地址
|
||||
## 小贴士
|
||||
|
||||
#### 预设配置已经开启自定义CUDA算子加速,速度更快,且显存消耗更少。如果你遇到可能的兼容性(输出乱码)问题,前往配置页面,关闭`使用自定义CUDA算子加速`,或更新你的显卡驱动
|
||||
- 你可以在服务器部署[backend-python](./backend-python/),然后将此程序仅用作客户端,在设置的`API URL`中填入你的服务器地址
|
||||
|
||||
#### 如果Windows Defender说这是一个病毒,你可以尝试下载[v1.3.7_win.zip](https://github.com/josStorer/RWKV-Runner/releases/download/v1.3.7/RWKV-Runner_win.zip),然后让其自动更新到最新版,或添加信任 (`Windows Security` -> `Virus & threat protection` -> `Manage settings` -> `Exclusions` -> `Add or remove exclusions` -> `Add an exclusion` -> `Folder` -> `RWKV-Runner`)
|
||||
- 如果你正在部署并对外提供公开服务,请通过API网关限制请求大小,避免过长的prompt提交占用资源。此外,请根据你的实际情况,限制请求的
|
||||
max_tokens 上限: https://github.com/josStorer/RWKV-Runner/blob/master/backend-python/utils/rwkv.py#L567,
|
||||
默认le=102400, 这可能导致极端情况下单个响应消耗大量资源
|
||||
|
||||
#### 对于不同的任务,调整API参数会获得更好的效果,例如对于翻译任务,你可以尝试设置Temperature为1,Top_P为0.3
|
||||
- 预设配置已经开启自定义CUDA算子加速,速度更快,且显存消耗更少。如果你遇到可能的兼容性(输出乱码)
|
||||
问题,前往配置页面,关闭`使用自定义CUDA算子加速`,或更新你的显卡驱动
|
||||
|
||||
- 如果 Windows Defender
|
||||
说这是一个病毒,你可以尝试下载[v1.3.7_win.zip](https://github.com/josStorer/RWKV-Runner/releases/download/v1.3.7/RWKV-Runner_win.zip),
|
||||
然后让其自动更新到最新版,或添加信任 (`Windows Security` -> `Virus & threat protection` -> `Manage settings` -> `Exclusions` -> `Add or remove exclusions` -> `Add an exclusion` -> `Folder` -> `RWKV-Runner`)
|
||||
|
||||
- 对于不同的任务,调整API参数会获得更好的效果,例如对于翻译任务,你可以尝试设置Temperature为1,Top_P为0.3
|
||||
|
||||
## 功能
|
||||
|
||||
- RWKV模型管理,一键启动
|
||||
- 与OpenAI API完全兼容,一切ChatGPT客户端,都是RWKV客户端。启动模型后,打开 http://127.0.0.1:8000/docs 查看详细内容
|
||||
- 前后端分离,如果你不想使用客户端,也允许单独部署前端服务,或后端推理服务,或具有WebUI的后端推理服务。
|
||||
[简明服务部署示例](#Simple-Deploy-Example) | [服务器部署示例](https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples)
|
||||
- 与OpenAI API兼容,一切ChatGPT客户端,都是RWKV客户端。启动模型后,打开 http://127.0.0.1:8000/docs 查看API文档
|
||||
- 全自动依赖安装,你只需要一个轻巧的可执行程序
|
||||
- 预设了2G至32G显存的配置,几乎在各种电脑上工作良好
|
||||
- 自带用户友好的聊天和续写交互页面
|
||||
- 易于理解和操作的参数配置
|
||||
- 预设多级显存配置,几乎在各种电脑上工作良好。通过配置页面切换Strategy到WebGPU,还可以在AMD,Intel等显卡上运行
|
||||
- 自带用户友好的聊天,续写,作曲交互页面。支持聊天预设,附件上传,MIDI硬件输入及音轨编辑。
|
||||
[预览](#Preview) | [MIDI硬件输入](#MIDI-Input)
|
||||
- 内置WebUI选项,一键启动Web服务,共享硬件资源
|
||||
- 易于理解和操作的参数配置,及各类操作引导提示
|
||||
- 内置模型转换工具
|
||||
- 内置下载管理和远程模型检视
|
||||
- 内置一键LoRA微调
|
||||
- 也可用作 OpenAI ChatGPT 和 GPT Playground 客户端
|
||||
- 内置一键LoRA微调 (仅限Windows)
|
||||
- 也可用作 OpenAI ChatGPT 和 GPT Playground 客户端 (在设置内填写API URL和API Key)
|
||||
- 多语言本地化
|
||||
- 主题切换
|
||||
- 自动更新
|
||||
|
||||
## Simple Deploy Example
|
||||
|
||||
```bash
|
||||
git clone https://github.com/josStorer/RWKV-Runner
|
||||
|
||||
# 然后
|
||||
cd RWKV-Runner
|
||||
python ./backend-python/main.py #后端推理服务已启动, 调用/switch-model载入模型, 参考API文档: http://127.0.0.1:8000/docs
|
||||
|
||||
# 或者
|
||||
cd RWKV-Runner/frontend
|
||||
npm ci
|
||||
npm run build #编译前端
|
||||
cd ..
|
||||
python ./backend-python/webui_server.py #单独启动前端服务
|
||||
# 或者
|
||||
python ./backend-python/main.py --webui #同时启动前后端服务
|
||||
|
||||
# 帮助参数
|
||||
python ./backend-python/main.py -h
|
||||
```
|
||||
|
||||
## API并发压力测试
|
||||
|
||||
```bash
|
||||
@@ -130,40 +165,88 @@ for i in np.argsort(embeddings_cos_sim)[::-1]:
|
||||
print(f"{embeddings_cos_sim[i]:.10f} - {values[i]}")
|
||||
```
|
||||
|
||||
## MIDI Input
|
||||
|
||||
小贴士: 你可以下载 https://github.com/josStorer/sgm_plus, 并解压到程序的`assets/sound-font`目录, 以使用离线音源. 注意,
|
||||
如果你正在从源码编译程序, 请不要将其放置在源码目录中
|
||||
|
||||
如果你没有MIDI键盘, 你可以使用像 `Virtual Midi Controller 3 LE` 这样的虚拟MIDI输入软件,
|
||||
配合[loopMIDI](https://www.tobias-erichsen.de/wp-content/uploads/2020/01/loopMIDISetup_1_0_16_27.zip), 使用普通电脑键盘作为MIDI输入
|
||||
|
||||
### USB MIDI 连接
|
||||
|
||||
- USB MIDI设备是即插即用的, 你能够在作曲页面选择你的输入设备
|
||||
- 
|
||||
|
||||
### Mac MIDI 蓝牙连接
|
||||
|
||||
- 对于想要使用蓝牙输入的Mac用户,
|
||||
请安装[Bluetooth MIDI Connect](https://apps.apple.com/us/app/bluetooth-midi-connect/id1108321791), 启动后点击托盘连接,
|
||||
之后你可以在作曲页面选择你的输入设备
|
||||
- 
|
||||
|
||||
### Windows MIDI 蓝牙连接
|
||||
|
||||
- Windows似乎只为UWP实现了蓝牙MIDI支持, 因此需要多个步骤进行连接, 我们需要创建一个本地的虚拟MIDI设备, 然后启动一个UWP应用,
|
||||
通过此UWP应用将蓝牙MIDI输入重定向到虚拟MIDI设备, 然后本软件监听虚拟MIDI设备的输入
|
||||
- 因此, 首先你需要下载[loopMIDI](https://www.tobias-erichsen.de/wp-content/uploads/2020/01/loopMIDISetup_1_0_16_27.zip),
|
||||
用于创建虚拟MIDI设备, 点击左下角的加号创建设备
|
||||
- 
|
||||
- 然后, 你需要下载[Bluetooth LE Explorer](https://apps.microsoft.com/detail/9N0ZTKF1QD98), 以发现并连接蓝牙MIDI设备,
|
||||
点击Start搜索设备, 然后点击Pair绑定MIDI设备
|
||||
- 
|
||||
- 最后, 你需要安装[MIDIberry](https://apps.microsoft.com/detail/9N39720H2M05), 这个UWP应用能将MIDI蓝牙输入重定向到虚拟MIDI设备,
|
||||
启动后, 在输入栏, 双击你实际的蓝牙MIDI设备名称, 在输出栏, 双击我们先前创建的虚拟MIDI设备名称
|
||||
- 
|
||||
- 现在, 你可以在作曲页面选择虚拟MIDI设备作为输入. Bluetooth LE Explorer不再需要运行, loopMIDI窗口也可以退出, 它会自动在后台运行,
|
||||
仅保持MIDIberry打开即可
|
||||
- 
|
||||
|
||||
## 相关仓库:
|
||||
|
||||
- RWKV-5-World: https://huggingface.co/BlinkDL/rwkv-5-world/tree/main
|
||||
- RWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main
|
||||
- RWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main
|
||||
- ChatRWKV: https://github.com/BlinkDL/ChatRWKV
|
||||
- RWKV-LM: https://github.com/BlinkDL/RWKV-LM
|
||||
- RWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA
|
||||
- MIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer
|
||||
- ai00_rwkv_server: https://github.com/cgisky1980/ai00_rwkv_server
|
||||
- rwkv.cpp: https://github.com/saharNooby/rwkv.cpp
|
||||
- web-rwkv-py: https://github.com/cryscan/web-rwkv-py
|
||||
|
||||
## Preview
|
||||
|
||||
### 主页
|
||||
|
||||

|
||||

|
||||
|
||||
### 聊天
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### 续写
|
||||
|
||||

|
||||
|
||||
### 作曲
|
||||
|
||||
小贴士: 你可以下载 https://github.com/josStorer/sgm_plus, 并解压到程序的`assets/sound-font`目录, 以使用离线音源. 注意,
|
||||
如果你正在从源码编译程序, 请不要将其放置在源码目录中
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### 配置
|
||||
|
||||

|
||||

|
||||
|
||||
### 模型管理
|
||||
|
||||

|
||||

|
||||
|
||||
### 下载管理
|
||||
|
||||
|
||||
@@ -46,12 +46,16 @@ func (a *App) OnStartup(ctx context.Context) {
|
||||
}
|
||||
|
||||
os.Chmod(a.exDir+"backend-rust/webgpu_server", 0777)
|
||||
os.Chmod(a.exDir+"backend-rust/web-rwkv-converter", 0777)
|
||||
os.Mkdir(a.exDir+"models", os.ModePerm)
|
||||
os.Mkdir(a.exDir+"lora-models", os.ModePerm)
|
||||
os.Mkdir(a.exDir+"finetune/json2binidx_tool/data", os.ModePerm)
|
||||
f, err := os.Create(a.exDir + "lora-models/train_log.txt")
|
||||
if err == nil {
|
||||
f.Close()
|
||||
trainLogPath := a.exDir + "lora-models/train_log.txt"
|
||||
if !a.FileExists(trainLogPath) {
|
||||
f, err := os.Create(trainLogPath)
|
||||
if err == nil {
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
|
||||
a.downloadLoop()
|
||||
|
||||
@@ -14,6 +14,13 @@ import (
|
||||
wruntime "github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
)
|
||||
|
||||
func (a *App) SaveFile(path string, savedContent []byte) error {
|
||||
if err := os.WriteFile(a.exDir+path, savedContent, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) SaveJson(fileName string, jsonData any) error {
|
||||
text, err := json.MarshalIndent(jsonData, "", " ")
|
||||
if err != nil {
|
||||
@@ -195,3 +202,12 @@ func (a *App) OpenFileFolder(path string, relative bool) error {
|
||||
}
|
||||
return errors.New("unsupported OS")
|
||||
}
|
||||
|
||||
func (a *App) StartFile(path string) error {
|
||||
cmd, err := CmdHelper(true, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = cmd.Start()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -86,6 +86,12 @@ func (a *App) OpenMidiPort(index int) error {
|
||||
channel := bytes[0] & 0x0f
|
||||
switch msgType {
|
||||
case 0x8:
|
||||
elapsed := time.Since(lastNoteTime)
|
||||
lastNoteTime = time.Now()
|
||||
runtime.EventsEmit(a.ctx, "midiMessage", &MIDIMessage{
|
||||
MessageType: "ElapsedTime",
|
||||
Value: int(elapsed.Milliseconds()),
|
||||
})
|
||||
note := bytes[1]
|
||||
runtime.EventsEmit(a.ctx, "midiMessage", &MIDIMessage{
|
||||
MessageType: "NoteOff",
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (a *App) StartServer(python string, port int, host string, webui bool, rwkvBeta bool) (string, error) {
|
||||
func (a *App) StartServer(python string, port int, host string, webui bool, rwkvBeta bool, rwkvcpp bool, webgpu bool) (string, error) {
|
||||
var err error
|
||||
if python == "" {
|
||||
python, err = GetPython()
|
||||
@@ -25,6 +25,12 @@ func (a *App) StartServer(python string, port int, host string, webui bool, rwkv
|
||||
if rwkvBeta {
|
||||
args = append(args, "--rwkv-beta")
|
||||
}
|
||||
if rwkvcpp {
|
||||
args = append(args, "--rwkv.cpp")
|
||||
}
|
||||
if webgpu {
|
||||
args = append(args, "--webgpu")
|
||||
}
|
||||
args = append(args, "--port", strconv.Itoa(port), "--host", host)
|
||||
return Cmd(args...)
|
||||
}
|
||||
@@ -46,7 +52,13 @@ func (a *App) ConvertModel(python string, modelPath string, strategy string, out
|
||||
return Cmd(python, "./backend-python/convert_model.py", "--in", modelPath, "--out", outPath, "--strategy", strategy)
|
||||
}
|
||||
|
||||
func (a *App) ConvertSafetensors(python string, modelPath string, outPath string) (string, error) {
|
||||
func (a *App) ConvertSafetensors(modelPath string, outPath string) (string, error) {
|
||||
args := []string{"./backend-rust/web-rwkv-converter"}
|
||||
args = append(args, "--input", modelPath, "--output", outPath)
|
||||
return Cmd(args...)
|
||||
}
|
||||
|
||||
func (a *App) ConvertSafetensorsWithPython(python string, modelPath string, outPath string) (string, error) {
|
||||
var err error
|
||||
if python == "" {
|
||||
python, err = GetPython()
|
||||
@@ -57,6 +69,21 @@ func (a *App) ConvertSafetensors(python string, modelPath string, outPath string
|
||||
return Cmd(python, "./backend-python/convert_safetensors.py", "--input", modelPath, "--output", outPath)
|
||||
}
|
||||
|
||||
func (a *App) ConvertGGML(python string, modelPath string, outPath string, Q51 bool) (string, error) {
|
||||
var err error
|
||||
if python == "" {
|
||||
python, err = GetPython()
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
dataType := "FP16"
|
||||
if Q51 {
|
||||
dataType = "Q5_1"
|
||||
}
|
||||
return Cmd(python, "./backend-python/convert_pytorch_to_ggml.py", modelPath, outPath, dataType)
|
||||
}
|
||||
|
||||
func (a *App) ConvertData(python string, input string, outputPrefix string, vocab string) (string, error) {
|
||||
var err error
|
||||
if python == "" {
|
||||
|
||||
@@ -3,6 +3,7 @@ package backend_golang
|
||||
import (
|
||||
"archive/zip"
|
||||
"bufio"
|
||||
"crypto/sha256"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -15,33 +16,50 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func CmdHelper(hideWindow bool, args ...string) (*exec.Cmd, error) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil, errors.New("unsupported OS")
|
||||
}
|
||||
filename := "./cmd-helper.bat"
|
||||
_, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
if err := os.WriteFile(filename, []byte("start %*"), 0644); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
cmdHelper, err := filepath.Abs(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if strings.Contains(cmdHelper, " ") {
|
||||
for _, arg := range args {
|
||||
if strings.Contains(arg, " ") {
|
||||
return nil, errors.New("path contains space") // golang bug https://github.com/golang/go/issues/17149#issuecomment-473976818
|
||||
}
|
||||
}
|
||||
}
|
||||
cmd := exec.Command(cmdHelper, args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
//go:custom_build windows cmd.SysProcAttr.HideWindow = hideWindow
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func Cmd(args ...string) (string, error) {
|
||||
switch platform := runtime.GOOS; platform {
|
||||
case "windows":
|
||||
if err := os.WriteFile("./cmd-helper.bat", []byte("start %*"), 0644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
cmdHelper, err := filepath.Abs("./cmd-helper")
|
||||
cmd, err := CmdHelper(true, args...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if strings.Contains(cmdHelper, " ") {
|
||||
for _, arg := range args {
|
||||
if strings.Contains(arg, " ") {
|
||||
return "", errors.New("path contains space") // golang bug https://github.com/golang/go/issues/17149#issuecomment-473976818
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cmd := exec.Command(cmdHelper, args...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
_, err = cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(out), nil
|
||||
return "", nil
|
||||
case "darwin":
|
||||
ex, err := os.Executable()
|
||||
if err != nil {
|
||||
@@ -95,9 +113,19 @@ func CopyEmbed(efs embed.FS) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.WriteFile(path, content, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
executeWrite := true
|
||||
existedContent, err := os.ReadFile(path)
|
||||
if err == nil {
|
||||
if fmt.Sprintf("%x", sha256.Sum256(existedContent)) == fmt.Sprintf("%x", sha256.Sum256(content)) {
|
||||
executeWrite = false
|
||||
}
|
||||
}
|
||||
|
||||
if executeWrite {
|
||||
err = os.WriteFile(path, content, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
169
backend-python/convert_pytorch_to_ggml.py
vendored
Normal file
169
backend-python/convert_pytorch_to_ggml.py
vendored
Normal file
@@ -0,0 +1,169 @@
|
||||
# Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file.
|
||||
# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M-FP16.bin FP16
|
||||
# Get model checkpoints from https://huggingface.co/BlinkDL
|
||||
# See FILE_FORMAT.md for the documentation on the file format.
|
||||
|
||||
import argparse
|
||||
import struct
|
||||
import torch
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file"
|
||||
)
|
||||
parser.add_argument("src_path", help="Path to PyTorch checkpoint file")
|
||||
parser.add_argument(
|
||||
"dest_path", help="Path to rwkv.cpp checkpoint file, will be overwritten"
|
||||
)
|
||||
parser.add_argument(
|
||||
"data_type",
|
||||
help="Data type, FP16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0",
|
||||
type=str,
|
||||
choices=[
|
||||
"FP16",
|
||||
"Q4_0",
|
||||
"Q4_1",
|
||||
"Q5_0",
|
||||
"Q5_1",
|
||||
"Q8_0",
|
||||
],
|
||||
default="FP16",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
|
||||
n_layer: int = 0
|
||||
|
||||
while f"blocks.{n_layer}.ln1.weight" in state_dict:
|
||||
n_layer += 1
|
||||
|
||||
assert n_layer > 0
|
||||
|
||||
return n_layer
|
||||
|
||||
|
||||
def write_state_dict(
|
||||
state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str
|
||||
) -> None:
|
||||
emb_weight: torch.Tensor = state_dict["emb.weight"]
|
||||
|
||||
n_layer: int = get_layer_count(state_dict)
|
||||
n_vocab: int = emb_weight.shape[0]
|
||||
n_embed: int = emb_weight.shape[1]
|
||||
|
||||
is_v5_1_or_2: bool = "blocks.0.att.ln_x.weight" in state_dict
|
||||
is_v5_2: bool = "blocks.0.att.gate.weight" in state_dict
|
||||
|
||||
if is_v5_2:
|
||||
print("Detected RWKV v5.2")
|
||||
elif is_v5_1_or_2:
|
||||
print("Detected RWKV v5.1")
|
||||
else:
|
||||
print("Detected RWKV v4")
|
||||
|
||||
with open(dest_path, "wb") as out_file:
|
||||
is_FP16: bool = data_type == "FP16" or data_type == "float16"
|
||||
|
||||
out_file.write(
|
||||
struct.pack(
|
||||
# Disable padding with '='
|
||||
"=iiiiii",
|
||||
# Magic: 'ggmf' in hex
|
||||
0x67676D66,
|
||||
101,
|
||||
n_vocab,
|
||||
n_embed,
|
||||
n_layer,
|
||||
1 if is_FP16 else 0,
|
||||
)
|
||||
)
|
||||
|
||||
for k in state_dict.keys():
|
||||
tensor: torch.Tensor = state_dict[k].float()
|
||||
|
||||
if ".time_" in k:
|
||||
tensor = tensor.squeeze()
|
||||
|
||||
if is_v5_1_or_2:
|
||||
if ".time_decay" in k:
|
||||
if is_v5_2:
|
||||
tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
|
||||
else:
|
||||
tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)
|
||||
|
||||
if ".time_first" in k:
|
||||
tensor = torch.exp(tensor).reshape(-1, 1, 1)
|
||||
|
||||
if ".time_faaaa" in k:
|
||||
tensor = tensor.unsqueeze(-1)
|
||||
else:
|
||||
if ".time_decay" in k:
|
||||
tensor = -torch.exp(tensor)
|
||||
|
||||
# Keep 1-dim vectors and small matrices in FP32
|
||||
if is_FP16 and len(tensor.shape) > 1 and ".time_" not in k:
|
||||
tensor = tensor.half()
|
||||
|
||||
shape = tensor.shape
|
||||
|
||||
print(f"Writing {k}, shape {shape}, type {tensor.dtype}")
|
||||
|
||||
k_encoded: bytes = k.encode("utf-8")
|
||||
|
||||
out_file.write(
|
||||
struct.pack(
|
||||
"=iii",
|
||||
len(shape),
|
||||
len(k_encoded),
|
||||
1 if tensor.dtype == torch.float16 else 0,
|
||||
)
|
||||
)
|
||||
|
||||
# Dimension order is reversed here:
|
||||
# * PyTorch shape is (x rows, y columns)
|
||||
# * ggml shape is (y elements in a row, x elements in a column)
|
||||
# Both shapes represent the same tensor.
|
||||
for dim in reversed(tensor.shape):
|
||||
out_file.write(struct.pack("=i", dim))
|
||||
|
||||
out_file.write(k_encoded)
|
||||
|
||||
tensor.numpy().tofile(out_file)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
print(f"Reading {args.src_path}")
|
||||
|
||||
state_dict: Dict[str, torch.Tensor] = torch.load(args.src_path, map_location="cpu")
|
||||
|
||||
temp_output: str = args.dest_path
|
||||
if args.data_type.startswith("Q"):
|
||||
import re
|
||||
|
||||
temp_output = re.sub(r"Q[4,5,8]_[0,1]", "fp16", temp_output)
|
||||
write_state_dict(state_dict, temp_output, "FP16")
|
||||
if args.data_type.startswith("Q"):
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
||||
from rwkv_pip.cpp import rwkv_cpp_shared_library
|
||||
|
||||
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
|
||||
library.rwkv_quantize_model_file(temp_output, args.dest_path, args.data_type)
|
||||
|
||||
print("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
with open("error.txt", "w") as f:
|
||||
f.write(str(e))
|
||||
89
backend-python/convert_safetensors.py
vendored
89
backend-python/convert_safetensors.py
vendored
@@ -1,9 +1,8 @@
|
||||
import json
|
||||
import collections
|
||||
import numpy
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from safetensors.torch import serialize_file, load_file
|
||||
|
||||
import argparse
|
||||
|
||||
@@ -25,34 +24,64 @@ def rename_key(rename, name):
|
||||
return name
|
||||
|
||||
|
||||
def convert_file(pt_filename: str, sf_filename: str, transpose_names=[], rename={}):
|
||||
loaded = torch.load(pt_filename, map_location="cpu")
|
||||
def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names=[]):
|
||||
loaded: collections.OrderedDict = torch.load(pt_filename, map_location="cpu")
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
|
||||
loaded = {k: v.clone().half() for k, v in loaded.items()}
|
||||
# for k, v in loaded.items():
|
||||
# print(f'{k}\t{v.shape}\t{v.dtype}')
|
||||
kk = list(loaded.keys())
|
||||
version = 4
|
||||
for x in kk:
|
||||
if "ln_x" in x:
|
||||
version = max(5, version)
|
||||
if "gate.weight" in x:
|
||||
version = max(5.1, version)
|
||||
if int(version) == 5 and "att.time_decay" in x:
|
||||
if len(loaded[x].shape) > 1:
|
||||
if loaded[x].shape[1] > 1:
|
||||
version = max(5.2, version)
|
||||
if "time_maa" in x:
|
||||
version = max(6, version)
|
||||
|
||||
# For tensors to be contiguous
|
||||
for k, v in loaded.items():
|
||||
print(f"Model detected: v{version:.1f}")
|
||||
|
||||
if version == 5.1:
|
||||
_, n_emb = loaded["emb.weight"].shape
|
||||
for k in kk:
|
||||
if "time_decay" in k or "time_faaaa" in k:
|
||||
# print(k, mm[k].shape)
|
||||
loaded[k] = (
|
||||
loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])
|
||||
)
|
||||
|
||||
for k in kk:
|
||||
new_k = rename_key(rename, k).lower()
|
||||
v = loaded[k].half()
|
||||
del loaded[k]
|
||||
for transpose_name in transpose_names:
|
||||
if transpose_name in k:
|
||||
loaded[k] = v.transpose(0, 1)
|
||||
loaded = {rename_key(rename, k).lower(): v.contiguous() for k, v in loaded.items()}
|
||||
|
||||
for k, v in loaded.items():
|
||||
print(f"{k}\t{v.shape}\t{v.dtype}")
|
||||
v = v.transpose(0, 1)
|
||||
print(f"{new_k}\t{v.shape}\t{v.dtype}")
|
||||
loaded[new_k] = {
|
||||
"dtype": str(v.dtype).split(".")[-1],
|
||||
"shape": v.shape,
|
||||
"data": v.numpy().tobytes(),
|
||||
}
|
||||
|
||||
dirname = os.path.dirname(sf_filename)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||
reloaded = load_file(sf_filename)
|
||||
for k in loaded:
|
||||
pt_tensor = loaded[k]
|
||||
sf_tensor = reloaded[k]
|
||||
if not torch.equal(pt_tensor, sf_tensor):
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
serialize_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||
# reloaded = load_file(sf_filename)
|
||||
# for k in loaded:
|
||||
# pt_tensor = torch.Tensor(
|
||||
# numpy.frombuffer(
|
||||
# bytearray(loaded[k]["data"]),
|
||||
# dtype=getattr(numpy, loaded[k]["dtype"]),
|
||||
# ).reshape(loaded[k]["shape"])
|
||||
# )
|
||||
# sf_tensor = reloaded[k]
|
||||
# if not torch.equal(pt_tensor, sf_tensor):
|
||||
# raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -60,8 +89,18 @@ if __name__ == "__main__":
|
||||
convert_file(
|
||||
args.input,
|
||||
args.output,
|
||||
["lora_A"],
|
||||
{"time_faaaa": "time_first", "lora_A": "lora.0", "lora_B": "lora.1"},
|
||||
rename={
|
||||
"time_faaaa": "time_first",
|
||||
"time_maa": "time_mix",
|
||||
"lora_A": "lora.0",
|
||||
"lora_B": "lora.1",
|
||||
},
|
||||
transpose_names=[
|
||||
"time_mix_w1",
|
||||
"time_mix_w2",
|
||||
"time_decay_w1",
|
||||
"time_decay_w2",
|
||||
],
|
||||
)
|
||||
print(f"Saved to {args.output}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -32,6 +32,16 @@ def get_args(args: Union[Sequence[str], None] = None):
|
||||
action="store_true",
|
||||
help="whether to use rwkv-beta (default: False)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--rwkv.cpp",
|
||||
action="store_true",
|
||||
help="whether to use rwkv.cpp (default: False)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--webgpu",
|
||||
action="store_true",
|
||||
help="whether to use webgpu (default: False)",
|
||||
)
|
||||
args = parser.parse_args(args)
|
||||
|
||||
return args
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
rwkv==0.8.16
|
||||
rwkv==0.8.22
|
||||
langchain==0.0.322
|
||||
fastapi==0.104.0
|
||||
uvicorn==0.23.2
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
rwkv==0.8.16
|
||||
rwkv==0.8.22
|
||||
langchain==0.0.322
|
||||
fastapi==0.104.0
|
||||
uvicorn==0.23.2
|
||||
|
||||
@@ -8,7 +8,6 @@ import base64
|
||||
from fastapi import APIRouter, Request, status, HTTPException
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
from utils.rwkv import *
|
||||
from utils.log import quick_log
|
||||
@@ -335,6 +334,8 @@ The following is a coherent verbose detailed conversation between a girl named {
|
||||
body.stop.append(f"\n\n{bot_code}")
|
||||
elif body.stop is None:
|
||||
body.stop = default_stop
|
||||
if not body.presystem:
|
||||
body.stop.append("\n\n")
|
||||
|
||||
if body.stream:
|
||||
return EventSourceResponse(
|
||||
@@ -396,6 +397,8 @@ class EmbeddingsBody(BaseModel):
|
||||
|
||||
|
||||
def embedding_base64(embedding: List[float]) -> str:
|
||||
import numpy as np
|
||||
|
||||
return base64.b64encode(np.array(embedding).astype(np.float32)).decode("utf-8")
|
||||
|
||||
|
||||
|
||||
@@ -49,13 +49,20 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
||||
if body.model == "":
|
||||
return "success"
|
||||
|
||||
if "->" in body.strategy:
|
||||
state_cache.disable_state_cache()
|
||||
else:
|
||||
try:
|
||||
state_cache.enable_state_cache()
|
||||
except HTTPException:
|
||||
pass
|
||||
devices = set(
|
||||
[
|
||||
x.strip().split(" ")[0].replace("cuda:0", "cuda")
|
||||
for x in body.strategy.split("->")
|
||||
]
|
||||
)
|
||||
print(f"Strategy Devices: {devices}")
|
||||
# if len(devices) > 1:
|
||||
# state_cache.disable_state_cache()
|
||||
# else:
|
||||
try:
|
||||
state_cache.enable_state_cache()
|
||||
except HTTPException:
|
||||
pass
|
||||
|
||||
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
|
||||
|
||||
@@ -67,6 +74,10 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
import traceback
|
||||
|
||||
print(traceback.format_exc())
|
||||
|
||||
quick_log(request, body, f"Exception: {e}")
|
||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
|
||||
raise HTTPException(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import io
|
||||
import global_var
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, status
|
||||
from starlette.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from utils.midi import *
|
||||
@@ -33,6 +33,20 @@ def text_to_midi(body: TextToMidiBody):
|
||||
return StreamingResponse(mid_data, media_type="audio/midi")
|
||||
|
||||
|
||||
@router.post("/midi-to-text", tags=["MIDI"])
|
||||
async def midi_to_text(file_data: UploadFile):
|
||||
vocab_config = "backend-python/utils/midi_vocab_config.json"
|
||||
cfg = VocabConfig.from_json(vocab_config)
|
||||
filter_config = "backend-python/utils/midi_filter_config.json"
|
||||
filter_cfg = FilterConfig.from_json(filter_config)
|
||||
mid = mido.MidiFile(file=file_data.file)
|
||||
output_list = convert_midi_to_str(cfg, filter_cfg, mid)
|
||||
if len(output_list) == 0:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad midi file")
|
||||
|
||||
return {"text": output_list[0]}
|
||||
|
||||
|
||||
class TxtToMidiBody(BaseModel):
|
||||
txt_path: str
|
||||
midi_path: str
|
||||
|
||||
@@ -44,6 +44,7 @@ def disable_state_cache():
|
||||
dtrie = {}
|
||||
gc.collect()
|
||||
|
||||
print("state cache disabled")
|
||||
return "success"
|
||||
|
||||
|
||||
@@ -61,8 +62,10 @@ def enable_state_cache():
|
||||
dtrie = {}
|
||||
gc.collect()
|
||||
|
||||
print("state cache enabled")
|
||||
return "success"
|
||||
except ModuleNotFoundError:
|
||||
print("state cache disabled")
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")
|
||||
|
||||
|
||||
@@ -84,17 +87,27 @@ def add_state(body: AddStateBody):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
devices: List[torch.device] = []
|
||||
state: Union[Any, None] = None
|
||||
|
||||
if body.state is not None:
|
||||
if type(body.state) == list and hasattr(body.state[0], "device"): # torch
|
||||
devices = [tensor.device for tensor in body.state]
|
||||
state = [tensor.cpu() for tensor in body.state]
|
||||
elif type(body.state) == np.ndarray: # rwkv.cpp
|
||||
state = body.state
|
||||
else: # WebGPU
|
||||
state = body.state.back()
|
||||
|
||||
id: int = trie.insert(body.prompt)
|
||||
device: torch.device = body.state[0].device
|
||||
dtrie[id] = {
|
||||
"tokens": copy.deepcopy(body.tokens),
|
||||
"state": [tensor.cpu() for tensor in body.state]
|
||||
if device != torch.device("cpu")
|
||||
else copy.deepcopy(body.state),
|
||||
"logits": copy.deepcopy(body.logits),
|
||||
"device": device,
|
||||
"tokens": body.tokens,
|
||||
"state": state,
|
||||
"logits": body.logits,
|
||||
"devices": devices,
|
||||
}
|
||||
|
||||
if len(trie) >= max_trie_len:
|
||||
@@ -168,6 +181,7 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
id = -1
|
||||
try:
|
||||
@@ -176,28 +190,23 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
except:
|
||||
pass
|
||||
if id != -1:
|
||||
v = dtrie[id]
|
||||
device: torch.device = v["device"]
|
||||
prompt: str = trie[id]
|
||||
v = dtrie[id]
|
||||
devices: List[torch.device] = v["devices"]
|
||||
state: Union[Any, None] = v["state"]
|
||||
|
||||
if type(state) == list and hasattr(state[0], "device"): # torch
|
||||
state = [tensor.to(devices[i]) for i, tensor in enumerate(state)]
|
||||
|
||||
quick_log(request, body, "Hit:\n" + prompt)
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"tokens": v["tokens"],
|
||||
"state": [tensor.to(device) for tensor in v["state"]]
|
||||
if device != torch.device("cpu")
|
||||
else v["state"],
|
||||
"state": state,
|
||||
"logits": v["logits"],
|
||||
"device": device.type,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"prompt": "",
|
||||
"tokens": [],
|
||||
"state": None,
|
||||
"logits": None,
|
||||
"device": None,
|
||||
}
|
||||
return {"prompt": "", "tokens": [], "state": None, "logits": None}
|
||||
|
||||
|
||||
# @router.post("/save-state", tags=["State Cache"])
|
||||
|
||||
2
backend-python/rwkv_pip/beta/model.py
vendored
2
backend-python/rwkv_pip/beta/model.py
vendored
@@ -251,7 +251,7 @@ class RWKV(MyModule):
|
||||
)
|
||||
assert (
|
||||
w["_strategy"] == args.strategy_string
|
||||
) # if you are using a new strategy, re-convert the model
|
||||
), "model has been converted and does not match current strategy; if you are using a new strategy, re-convert the model"
|
||||
assert (
|
||||
float(w["_version"]) >= 0.7
|
||||
) # sometimes you should re-convert using latest convert_model.py
|
||||
|
||||
BIN
backend-python/rwkv_pip/cpp/librwkv.dylib
vendored
Normal file
BIN
backend-python/rwkv_pip/cpp/librwkv.dylib
vendored
Normal file
Binary file not shown.
BIN
backend-python/rwkv_pip/cpp/librwkv.so
vendored
Normal file
BIN
backend-python/rwkv_pip/cpp/librwkv.so
vendored
Normal file
Binary file not shown.
14
backend-python/rwkv_pip/cpp/model.py
vendored
Normal file
14
backend-python/rwkv_pip/cpp/model.py
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
from typing import Any, List, Union
|
||||
from . import rwkv_cpp_model
|
||||
from . import rwkv_cpp_shared_library
|
||||
|
||||
|
||||
class RWKV:
|
||||
def __init__(self, model_path: str, strategy=None):
|
||||
self.library = rwkv_cpp_shared_library.load_rwkv_shared_library()
|
||||
self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
|
||||
self.w = {} # fake weight
|
||||
self.w["emb.weight"] = [0] * self.model.n_vocab
|
||||
|
||||
def forward(self, tokens: List[int], state: Union[Any, None] = None):
|
||||
return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)
|
||||
BIN
backend-python/rwkv_pip/cpp/rwkv.dll
vendored
Normal file
BIN
backend-python/rwkv_pip/cpp/rwkv.dll
vendored
Normal file
Binary file not shown.
369
backend-python/rwkv_pip/cpp/rwkv_cpp_model.py
vendored
Normal file
369
backend-python/rwkv_pip/cpp/rwkv_cpp_model.py
vendored
Normal file
@@ -0,0 +1,369 @@
|
||||
import os
|
||||
import multiprocessing
|
||||
|
||||
# Pre-import PyTorch, if available.
|
||||
# This fixes "OSError: [WinError 127] The specified procedure could not be found".
|
||||
try:
|
||||
import torch
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
# I'm sure this is not strictly correct, but let's keep this crutch for now.
|
||||
try:
|
||||
import rwkv_cpp_shared_library
|
||||
except ModuleNotFoundError:
|
||||
from . import rwkv_cpp_shared_library
|
||||
|
||||
from typing import TypeVar, Optional, Tuple, List
|
||||
|
||||
# A value of this type is either a numpy's ndarray or a PyTorch's Tensor.
|
||||
NumpyArrayOrPyTorchTensor: TypeVar = TypeVar('NumpyArrayOrPyTorchTensor')
|
||||
|
||||
class RWKVModel:
|
||||
"""
|
||||
An RWKV model managed by rwkv.cpp library.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
|
||||
model_path: str,
|
||||
thread_count: int = max(1, multiprocessing.cpu_count() // 2),
|
||||
gpu_layer_count: int = 0,
|
||||
**kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Loads the model and prepares it for inference.
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shared_library : RWKVSharedLibrary
|
||||
rwkv.cpp shared library.
|
||||
model_path : str
|
||||
Path to RWKV model file in ggml format.
|
||||
thread_count : int
|
||||
Thread count to use. If not set, defaults to CPU count / 2.
|
||||
gpu_layer_count : int
|
||||
Count of layers to offload onto the GPU, must be >= 0.
|
||||
See documentation of `gpu_offload_layers` for details about layer offloading.
|
||||
"""
|
||||
|
||||
if 'gpu_layers_count' in kwargs:
|
||||
gpu_layer_count = kwargs['gpu_layers_count']
|
||||
|
||||
assert os.path.isfile(model_path), f'{model_path} is not a file'
|
||||
assert thread_count > 0, 'Thread count must be > 0'
|
||||
assert gpu_layer_count >= 0, 'GPU layer count must be >= 0'
|
||||
|
||||
self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library
|
||||
|
||||
self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count)
|
||||
|
||||
if gpu_layer_count > 0:
|
||||
self.gpu_offload_layers(gpu_layer_count)
|
||||
|
||||
self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx)
|
||||
self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
|
||||
|
||||
self._valid: bool = True
|
||||
|
||||
def gpu_offload_layers(self, layer_count: int) -> bool:
|
||||
"""
|
||||
Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
|
||||
For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
|
||||
- pass `model.n_layer` to offload all layers except model head
|
||||
- pass `model.n_layer + 1` to offload all layers, including model head
|
||||
|
||||
Returns true if at least one layer was offloaded.
|
||||
If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
layer_count : int
|
||||
Count of layers to offload onto the GPU, must be >= 0.
|
||||
"""
|
||||
|
||||
assert layer_count >= 0, 'Layer count must be >= 0'
|
||||
|
||||
return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)
|
||||
|
||||
@property
|
||||
def n_vocab(self) -> int:
|
||||
return self._library.rwkv_get_n_vocab(self._ctx)
|
||||
|
||||
@property
|
||||
def n_embed(self) -> int:
|
||||
return self._library.rwkv_get_n_embed(self._ctx)
|
||||
|
||||
@property
|
||||
def n_layer(self) -> int:
|
||||
return self._library.rwkv_get_n_layer(self._ctx)
|
||||
|
||||
def eval(
|
||||
self,
|
||||
token: int,
|
||||
state_in: Optional[NumpyArrayOrPyTorchTensor],
|
||||
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
use_numpy: bool = False
|
||||
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
|
||||
"""
|
||||
Evaluates the model for a single token.
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
token : int
|
||||
Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
|
||||
state_in : Optional[NumpyArrayOrTorchTensor]
|
||||
State from previous call of this method. If this is a first pass, set it to None.
|
||||
state_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
|
||||
logits_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
|
||||
use_numpy : bool
|
||||
If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
|
||||
This parameter is ignored if any tensor parameter is not None; in such case,
|
||||
type of returned tensors will match the type of received tensors.
|
||||
|
||||
Returns
|
||||
-------
|
||||
logits, state
|
||||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Model was freed'
|
||||
|
||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||
|
||||
if state_in is not None:
|
||||
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
|
||||
|
||||
state_in_ptr = self._get_data_ptr(state_in)
|
||||
else:
|
||||
state_in_ptr = 0
|
||||
|
||||
if state_out is not None:
|
||||
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
|
||||
else:
|
||||
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
|
||||
|
||||
if logits_out is not None:
|
||||
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
|
||||
else:
|
||||
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
|
||||
|
||||
self._library.rwkv_eval(
|
||||
self._ctx,
|
||||
token,
|
||||
state_in_ptr,
|
||||
self._get_data_ptr(state_out),
|
||||
self._get_data_ptr(logits_out)
|
||||
)
|
||||
|
||||
return logits_out, state_out
|
||||
|
||||
def eval_sequence(
|
||||
self,
|
||||
tokens: List[int],
|
||||
state_in: Optional[NumpyArrayOrPyTorchTensor],
|
||||
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
use_numpy: bool = False
|
||||
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
|
||||
"""
|
||||
Evaluates the model for a sequence of tokens.
|
||||
|
||||
NOTE ON GGML NODE LIMIT
|
||||
|
||||
ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceedes
|
||||
this limit when using large models and/or large sequence lengths.
|
||||
Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
|
||||
|
||||
If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
|
||||
To get rid of the assertion failure, reduce the model size and/or sequence length.
|
||||
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : List[int]
|
||||
Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
|
||||
state_in : Optional[NumpyArrayOrTorchTensor]
|
||||
State from previous call of this method. If this is a first pass, set it to None.
|
||||
state_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
|
||||
logits_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
|
||||
use_numpy : bool
|
||||
If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
|
||||
This parameter is ignored if any tensor parameter is not None; in such case,
|
||||
type of returned tensors will match the type of received tensors.
|
||||
|
||||
Returns
|
||||
-------
|
||||
logits, state
|
||||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Model was freed'
|
||||
|
||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||
|
||||
if state_in is not None:
|
||||
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
|
||||
|
||||
state_in_ptr = self._get_data_ptr(state_in)
|
||||
else:
|
||||
state_in_ptr = 0
|
||||
|
||||
if state_out is not None:
|
||||
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
|
||||
else:
|
||||
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
|
||||
|
||||
if logits_out is not None:
|
||||
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
|
||||
else:
|
||||
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
|
||||
|
||||
self._library.rwkv_eval_sequence(
|
||||
self._ctx,
|
||||
tokens,
|
||||
state_in_ptr,
|
||||
self._get_data_ptr(state_out),
|
||||
self._get_data_ptr(logits_out)
|
||||
)
|
||||
|
||||
return logits_out, state_out
|
||||
|
||||
def eval_sequence_in_chunks(
|
||||
self,
|
||||
tokens: List[int],
|
||||
state_in: Optional[NumpyArrayOrPyTorchTensor],
|
||||
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
|
||||
chunk_size: int = 16,
|
||||
use_numpy: bool = False
|
||||
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
|
||||
"""
|
||||
Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
|
||||
This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
|
||||
It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance.
|
||||
|
||||
Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
|
||||
A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
|
||||
and choose one that works the best in your use case.
|
||||
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : List[int]
|
||||
Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
|
||||
chunk_size : int
|
||||
Size of each chunk in tokens, must be positive.
|
||||
state_in : Optional[NumpyArrayOrTorchTensor]
|
||||
State from previous call of this method. If this is a first pass, set it to None.
|
||||
state_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
|
||||
logits_out : Optional[NumpyArrayOrTorchTensor]
|
||||
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
|
||||
use_numpy : bool
|
||||
If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
|
||||
This parameter is ignored if any tensor parameter is not None; in such case,
|
||||
type of returned tensors will match the type of received tensors.
|
||||
|
||||
Returns
|
||||
-------
|
||||
logits, state
|
||||
Logits vector of shape (n_vocab); state for the next step.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Model was freed'
|
||||
|
||||
use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
|
||||
|
||||
if state_in is not None:
|
||||
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
|
||||
|
||||
state_in_ptr = self._get_data_ptr(state_in)
|
||||
else:
|
||||
state_in_ptr = 0
|
||||
|
||||
if state_out is not None:
|
||||
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
|
||||
else:
|
||||
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
|
||||
|
||||
if logits_out is not None:
|
||||
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
|
||||
else:
|
||||
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
|
||||
|
||||
self._library.rwkv_eval_sequence_in_chunks(
|
||||
self._ctx,
|
||||
tokens,
|
||||
chunk_size,
|
||||
state_in_ptr,
|
||||
self._get_data_ptr(state_out),
|
||||
self._get_data_ptr(logits_out)
|
||||
)
|
||||
|
||||
return logits_out, state_out
|
||||
|
||||
def free(self) -> None:
|
||||
"""
|
||||
Frees all allocated resources.
|
||||
In case of any error, this method will throw an exception.
|
||||
The object must not be used anymore after calling this method.
|
||||
"""
|
||||
|
||||
assert self._valid, 'Already freed'
|
||||
|
||||
self._valid = False
|
||||
|
||||
self._library.rwkv_free(self._ctx)
|
||||
|
||||
def __del__(self) -> None:
|
||||
# Free the context on GC in case user forgot to call free() explicitly.
|
||||
if hasattr(self, '_valid') and self._valid:
|
||||
self.free()
|
||||
|
||||
def _is_pytorch_tensor(self, tensor: NumpyArrayOrPyTorchTensor) -> bool:
|
||||
return hasattr(tensor, '__module__') and tensor.__module__ == 'torch'
|
||||
|
||||
def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]], use_numpy_by_default: bool) -> bool:
|
||||
for tensor in tensors:
|
||||
if tensor is not None:
|
||||
return False if self._is_pytorch_tensor(tensor) else True
|
||||
|
||||
return use_numpy_by_default
|
||||
|
||||
def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
|
||||
if self._is_pytorch_tensor(tensor):
|
||||
tensor: torch.Tensor = tensor
|
||||
assert tensor.device == torch.device('cpu'), f'{name} is not on CPU'
|
||||
assert tensor.dtype == torch.float32, f'{name} is not of type float32'
|
||||
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
|
||||
assert tensor.is_contiguous(), f'{name} is not contiguous'
|
||||
else:
|
||||
import numpy as np
|
||||
tensor: np.ndarray = tensor
|
||||
assert tensor.dtype == np.float32, f'{name} is not of type float32'
|
||||
assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
|
||||
assert tensor.data.contiguous, f'{name} is not contiguous'
|
||||
|
||||
def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
|
||||
if self._is_pytorch_tensor(tensor):
|
||||
return tensor.data_ptr()
|
||||
else:
|
||||
return tensor.ctypes.data
|
||||
|
||||
def _zeros_float32(self, element_count: int, use_numpy: bool) -> NumpyArrayOrPyTorchTensor:
|
||||
if use_numpy:
|
||||
import numpy as np
|
||||
return np.zeros(element_count, dtype=np.float32)
|
||||
else:
|
||||
return torch.zeros(element_count, dtype=torch.float32, device='cpu')
|
||||
444
backend-python/rwkv_pip/cpp/rwkv_cpp_shared_library.py
vendored
Normal file
444
backend-python/rwkv_pip/cpp/rwkv_cpp_shared_library.py
vendored
Normal file
@@ -0,0 +1,444 @@
|
||||
import os
|
||||
import sys
|
||||
import ctypes
|
||||
import pathlib
|
||||
import platform
|
||||
from typing import Optional, List, Tuple, Callable
|
||||
|
||||
QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = (
|
||||
'Q4_0',
|
||||
'Q4_1',
|
||||
'Q5_0',
|
||||
'Q5_1',
|
||||
'Q8_0'
|
||||
)
|
||||
|
||||
P_FLOAT = ctypes.POINTER(ctypes.c_float)
|
||||
P_INT = ctypes.POINTER(ctypes.c_int32)
|
||||
|
||||
class RWKVContext:
|
||||
|
||||
def __init__(self, ptr: ctypes.pointer) -> None:
|
||||
self.ptr: ctypes.pointer = ptr
|
||||
|
||||
class RWKVSharedLibrary:
|
||||
"""
|
||||
Python wrapper around rwkv.cpp shared library.
|
||||
"""
|
||||
|
||||
def __init__(self, shared_library_path: str) -> None:
|
||||
"""
|
||||
Loads the shared library from specified file.
|
||||
In case of any error, this method will throw an exception.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shared_library_path : str
|
||||
Path to rwkv.cpp shared library. On Windows, it would look like 'rwkv.dll'. On UNIX, 'rwkv.so'.
|
||||
"""
|
||||
# When Python is greater than 3.8, we need to reprocess the custom dll
|
||||
# according to the documentation to prevent loading failure errors.
|
||||
# https://docs.python.org/3/whatsnew/3.8.html#ctypes
|
||||
if platform.system().lower() == 'windows':
|
||||
self.library = ctypes.CDLL(shared_library_path, winmode=0)
|
||||
else:
|
||||
self.library = ctypes.cdll.LoadLibrary(shared_library_path)
|
||||
|
||||
self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32]
|
||||
self.library.rwkv_init_from_file.restype = ctypes.c_void_p
|
||||
|
||||
self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
|
||||
self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_eval.argtypes = [
|
||||
ctypes.c_void_p, # ctx
|
||||
ctypes.c_int32, # token
|
||||
P_FLOAT, # state_in
|
||||
P_FLOAT, # state_out
|
||||
P_FLOAT # logits_out
|
||||
]
|
||||
self.library.rwkv_eval.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_eval_sequence.argtypes = [
|
||||
ctypes.c_void_p, # ctx
|
||||
P_INT, # tokens
|
||||
ctypes.c_size_t, # token count
|
||||
P_FLOAT, # state_in
|
||||
P_FLOAT, # state_out
|
||||
P_FLOAT # logits_out
|
||||
]
|
||||
self.library.rwkv_eval_sequence.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_eval_sequence_in_chunks.argtypes = [
|
||||
ctypes.c_void_p, # ctx
|
||||
P_INT, # tokens
|
||||
ctypes.c_size_t, # token count
|
||||
ctypes.c_size_t, # chunk size
|
||||
P_FLOAT, # state_in
|
||||
P_FLOAT, # state_out
|
||||
P_FLOAT # logits_out
|
||||
]
|
||||
self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t
|
||||
|
||||
self.library.rwkv_get_n_embed.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_n_embed.restype = ctypes.c_size_t
|
||||
|
||||
self.library.rwkv_get_n_layer.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_n_layer.restype = ctypes.c_size_t
|
||||
|
||||
self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32
|
||||
|
||||
self.library.rwkv_get_logits_buffer_element_count.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_get_logits_buffer_element_count.restype = ctypes.c_uint32
|
||||
|
||||
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_free.restype = None
|
||||
|
||||
self.library.rwkv_free.argtypes = [ctypes.c_void_p]
|
||||
self.library.rwkv_free.restype = None
|
||||
|
||||
self.library.rwkv_quantize_model_file.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
|
||||
self.library.rwkv_quantize_model_file.restype = ctypes.c_bool
|
||||
|
||||
self.library.rwkv_get_system_info_string.argtypes = []
|
||||
self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p
|
||||
|
||||
self.nullptr = ctypes.cast(0, ctypes.c_void_p)
|
||||
|
||||
def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext:
|
||||
"""
|
||||
Loads the model from a file and prepares it for inference.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_file_path : str
|
||||
Path to model file in ggml format.
|
||||
thread_count : int
|
||||
Count of threads to use, must be positive.
|
||||
"""
|
||||
|
||||
ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))
|
||||
|
||||
assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
|
||||
|
||||
return RWKVContext(ptr)
|
||||
|
||||
def rwkv_gpu_offload_layers(self, ctx: RWKVContext, layer_count: int) -> bool:
|
||||
"""
|
||||
Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
|
||||
For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
|
||||
- pass `rwkv_get_n_layer(ctx)` to offload all layers except model head
|
||||
- pass `rwkv_get_n_layer(ctx) + 1` to offload all layers, including model head
|
||||
Returns true if at least one layer was offloaded.
|
||||
If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
layer_count : int
|
||||
Count of layers to offload onto the GPU, must be >= 0.
|
||||
"""
|
||||
|
||||
assert layer_count >= 0, 'Layer count must be >= 0'
|
||||
|
||||
return self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(layer_count))
|
||||
|
||||
def rwkv_eval(
|
||||
self,
|
||||
ctx: RWKVContext,
|
||||
token: int,
|
||||
state_in_address: Optional[int],
|
||||
state_out_address: int,
|
||||
logits_out_address: int
|
||||
) -> None:
|
||||
"""
|
||||
Evaluates the model for a single token.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
token : int
|
||||
Next token index, in range 0 <= token < n_vocab.
|
||||
state_in_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
|
||||
state_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
||||
logits_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||
"""
|
||||
|
||||
assert self.library.rwkv_eval(
|
||||
ctx.ptr,
|
||||
ctypes.c_int32(token),
|
||||
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
||||
ctypes.cast(state_out_address, P_FLOAT),
|
||||
ctypes.cast(logits_out_address, P_FLOAT)
|
||||
), 'rwkv_eval failed, check stderr'
|
||||
|
||||
def rwkv_eval_sequence(
|
||||
self,
|
||||
ctx: RWKVContext,
|
||||
tokens: List[int],
|
||||
state_in_address: Optional[int],
|
||||
state_out_address: int,
|
||||
logits_out_address: int
|
||||
) -> None:
|
||||
"""
|
||||
Evaluates the model for a sequence of tokens.
|
||||
Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
|
||||
Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
|
||||
|
||||
NOTE ON GGML NODE LIMIT
|
||||
|
||||
ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceedes
|
||||
this limit when using large models and/or large sequence lengths.
|
||||
Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
|
||||
|
||||
If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
|
||||
To get rid of the assertion failure, reduce the model size and/or sequence length.
|
||||
|
||||
Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
tokens : List[int]
|
||||
Next token indices, in range 0 <= token < n_vocab.
|
||||
state_in_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
|
||||
state_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
||||
logits_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||
"""
|
||||
|
||||
assert self.library.rwkv_eval_sequence(
|
||||
ctx.ptr,
|
||||
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
|
||||
ctypes.c_size_t(len(tokens)),
|
||||
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
||||
ctypes.cast(state_out_address, P_FLOAT),
|
||||
ctypes.cast(logits_out_address, P_FLOAT)
|
||||
), 'rwkv_eval_sequence failed, check stderr'
|
||||
|
||||
def rwkv_eval_sequence_in_chunks(
|
||||
self,
|
||||
ctx: RWKVContext,
|
||||
tokens: List[int],
|
||||
chunk_size: int,
|
||||
state_in_address: Optional[int],
|
||||
state_out_address: int,
|
||||
logits_out_address: int
|
||||
) -> None:
|
||||
"""
|
||||
Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
|
||||
This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
|
||||
It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.
|
||||
|
||||
Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
|
||||
A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
|
||||
and choose one that works the best in your use case.
|
||||
|
||||
Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
tokens : List[int]
|
||||
Next token indices, in range 0 <= token < n_vocab.
|
||||
chunk_size : int
|
||||
Size of each chunk in tokens, must be positive.
|
||||
state_in_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
|
||||
state_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
||||
logits_out_address : int
|
||||
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||
"""
|
||||
|
||||
assert self.library.rwkv_eval_sequence_in_chunks(
|
||||
ctx.ptr,
|
||||
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
|
||||
ctypes.c_size_t(len(tokens)),
|
||||
ctypes.c_size_t(chunk_size),
|
||||
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
|
||||
ctypes.cast(state_out_address, P_FLOAT),
|
||||
ctypes.cast(logits_out_address, P_FLOAT)
|
||||
), 'rwkv_eval_sequence_in_chunks failed, check stderr'
|
||||
|
||||
def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns the number of tokens in the given model's vocabulary.
|
||||
Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_n_vocab(ctx.ptr)
|
||||
|
||||
def rwkv_get_n_embed(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns the number of elements in the given model's embedding.
|
||||
Useful for reading individual fields of a model's hidden state.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_n_embed(ctx.ptr)
|
||||
|
||||
def rwkv_get_n_layer(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns the number of layers in the given model.
|
||||
A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model.
|
||||
Embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`.
|
||||
Useful for always offloading the entire model to GPU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_n_layer(ctx.ptr)
|
||||
|
||||
def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns count of FP32 elements in state buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_state_buffer_element_count(ctx.ptr)
|
||||
|
||||
def rwkv_get_logits_buffer_element_count(self, ctx: RWKVContext) -> int:
|
||||
"""
|
||||
Returns count of FP32 elements in logits buffer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_logits_buffer_element_count(ctx.ptr)
|
||||
|
||||
def rwkv_free(self, ctx: RWKVContext) -> None:
|
||||
"""
|
||||
Frees all allocated memory and the context.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx : RWKVContext
|
||||
RWKV context obtained from rwkv_init_from_file.
|
||||
"""
|
||||
|
||||
self.library.rwkv_free(ctx.ptr)
|
||||
|
||||
ctx.ptr = self.nullptr
|
||||
|
||||
def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None:
|
||||
"""
|
||||
Quantizes FP32 or FP16 model to one of INT4 formats.
|
||||
Throws an exception in case of any error. Error messages would be printed to stderr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_file_path_in : str
|
||||
Path to model file in ggml format, must be either FP32 or FP16.
|
||||
model_file_path_out : str
|
||||
Quantized model will be written here.
|
||||
format_name : str
|
||||
One of QUANTIZED_FORMAT_NAMES.
|
||||
"""
|
||||
|
||||
assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}'
|
||||
|
||||
assert self.library.rwkv_quantize_model_file(
|
||||
model_file_path_in.encode('utf-8'),
|
||||
model_file_path_out.encode('utf-8'),
|
||||
format_name.encode('utf-8')
|
||||
), 'rwkv_quantize_model_file failed, check stderr'
|
||||
|
||||
def rwkv_get_system_info_string(self) -> str:
|
||||
"""
|
||||
Returns system information string.
|
||||
"""
|
||||
|
||||
return self.library.rwkv_get_system_info_string().decode('utf-8')
|
||||
|
||||
def load_rwkv_shared_library() -> RWKVSharedLibrary:
|
||||
"""
|
||||
Attempts to find rwkv.cpp shared library and load it.
|
||||
To specify exact path to the library, create an instance of RWKVSharedLibrary explicitly.
|
||||
"""
|
||||
|
||||
file_name: str
|
||||
|
||||
if 'win32' in sys.platform or 'cygwin' in sys.platform:
|
||||
file_name = 'rwkv.dll'
|
||||
elif 'darwin' in sys.platform:
|
||||
file_name = 'librwkv.dylib'
|
||||
else:
|
||||
file_name = 'librwkv.so'
|
||||
|
||||
# Possible sub-paths to the library relative to the repo dir.
|
||||
child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [
|
||||
# No lookup for Debug config here.
|
||||
# I assume that if a user wants to debug the library,
|
||||
# they will be able to find the library and set the exact path explicitly.
|
||||
lambda p: p / 'backend-python' / 'rwkv_pip' / 'cpp' / file_name,
|
||||
lambda p: p / 'bin' / 'Release' / file_name,
|
||||
lambda p: p / 'bin' / file_name,
|
||||
# Some people prefer to build in the "build" subdirectory.
|
||||
lambda p: p / 'build' / 'bin' / 'Release' / file_name,
|
||||
lambda p: p / 'build' / 'bin' / file_name,
|
||||
lambda p: p / 'build' / file_name,
|
||||
# Fallback.
|
||||
lambda p: p / file_name
|
||||
]
|
||||
|
||||
working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd()))
|
||||
|
||||
parent_paths: List[pathlib.Path] = [
|
||||
# Possible repo dirs relative to the working dir.
|
||||
# ./python/rwkv_cpp
|
||||
working_dir.parent.parent,
|
||||
# ./python
|
||||
working_dir.parent,
|
||||
# .
|
||||
working_dir,
|
||||
# Repo dir relative to this Python file.
|
||||
pathlib.Path(os.path.abspath(__file__)).parent.parent.parent
|
||||
]
|
||||
|
||||
for parent_path in parent_paths:
|
||||
for child_path in child_paths:
|
||||
full_path: pathlib.Path = child_path(parent_path)
|
||||
|
||||
if os.path.isfile(full_path):
|
||||
return RWKVSharedLibrary(str(full_path))
|
||||
|
||||
assert False, (f'Failed to find {file_name} automatically; '
|
||||
f'you need to find the library and create RWKVSharedLibrary specifying the path to it')
|
||||
2
backend-python/rwkv_pip/model.py
vendored
2
backend-python/rwkv_pip/model.py
vendored
@@ -342,7 +342,7 @@ class RWKV(MyModule):
|
||||
)
|
||||
assert (
|
||||
w["_strategy"] == args.strategy_string
|
||||
) # if you are using a new strategy, re-convert the model
|
||||
), "model has been converted and does not match current strategy; if you are using a new strategy, re-convert the model"
|
||||
assert (
|
||||
float(w["_version"]) >= 0.7
|
||||
) # sometimes you should re-convert using latest convert_model.py
|
||||
|
||||
65532
backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt
vendored
Normal file
65532
backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt
vendored
Normal file
File diff suppressed because it is too large
Load Diff
39
backend-python/rwkv_pip/utils.py
vendored
39
backend-python/rwkv_pip/utils.py
vendored
@@ -34,6 +34,25 @@ class PIPELINE_ARGS:
|
||||
)
|
||||
|
||||
|
||||
class ABC_TOKENIZER:
|
||||
def __init__(self):
|
||||
self.pad_token_id = 0
|
||||
self.bos_token_id = 2
|
||||
self.eos_token_id = 3
|
||||
|
||||
def encode(self, text):
|
||||
ids = [ord(c) for c in text]
|
||||
return ids
|
||||
|
||||
def decode(self, ids):
|
||||
txt = "".join(
|
||||
chr(idx) if idx > self.eos_token_id else ""
|
||||
for idx in ids
|
||||
if idx != self.eos_token_id
|
||||
)
|
||||
return txt
|
||||
|
||||
|
||||
class PIPELINE:
|
||||
def __init__(self, model, WORD_NAME: str):
|
||||
self.model = model
|
||||
@@ -48,6 +67,8 @@ class PIPELINE:
|
||||
self.tokenizer = TRIE_TOKENIZER(
|
||||
os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt"
|
||||
)
|
||||
elif WORD_NAME == "abc_tokenizer":
|
||||
self.tokenizer = ABC_TOKENIZER()
|
||||
else:
|
||||
if WORD_NAME.endswith(".txt"):
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -78,12 +99,24 @@ class PIPELINE:
|
||||
def decode(self, x):
|
||||
return self.tokenizer.decode(x)
|
||||
|
||||
def np_softmax(self, x: np.ndarray, axis: int):
|
||||
x -= x.max(axis=axis, keepdims=True)
|
||||
e: np.ndarray = np.exp(x)
|
||||
return e / e.sum(axis=axis, keepdims=True)
|
||||
|
||||
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
|
||||
probs = F.softmax(logits.float(), dim=-1)
|
||||
if type(logits) == list:
|
||||
logits = np.array(logits)
|
||||
np_logits = type(logits) == np.ndarray
|
||||
if np_logits:
|
||||
probs = self.np_softmax(logits, axis=-1)
|
||||
else:
|
||||
probs = F.softmax(logits.float(), dim=-1)
|
||||
top_k = int(top_k)
|
||||
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
|
||||
if probs.device.type in ["cpu", "privateuseone"]:
|
||||
probs = probs.cpu().numpy()
|
||||
if np_logits or probs.device.type in ["cpu", "privateuseone"]:
|
||||
if not np_logits:
|
||||
probs = probs.cpu().numpy()
|
||||
sorted_ids = np.argsort(probs)
|
||||
sorted_probs = probs[sorted_ids][::-1]
|
||||
cumulative_probs = np.cumsum(sorted_probs)
|
||||
|
||||
31
backend-python/rwkv_pip/webgpu/model.py
vendored
Normal file
31
backend-python/rwkv_pip/webgpu/model.py
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import Any, List, Union
|
||||
|
||||
try:
|
||||
import web_rwkv_py as wrp
|
||||
except ModuleNotFoundError:
|
||||
try:
|
||||
from . import web_rwkv_py as wrp
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"web_rwkv_py not found, install it from https://github.com/cryscan/web-rwkv-py"
|
||||
)
|
||||
|
||||
|
||||
class RWKV:
|
||||
def __init__(self, model_path: str, strategy: str = None):
|
||||
self.model = wrp.v5.Model(
|
||||
model_path,
|
||||
turbo=True,
|
||||
quant=32 if "i8" in strategy else None,
|
||||
quant_nf4=26 if "i4" in strategy else None,
|
||||
)
|
||||
self.w = {} # fake weight
|
||||
self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab
|
||||
|
||||
def forward(self, tokens: List[int], state: Union[Any, None] = None):
|
||||
if type(state).__name__ == "BackedState": # memory state
|
||||
gpu_state = wrp.v5.ModelState(self.model, 1)
|
||||
gpu_state.load(state)
|
||||
else:
|
||||
gpu_state = state
|
||||
return wrp.v5.run_one(self.model, tokens, gpu_state)
|
||||
BIN
backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd
vendored
Normal file
BIN
backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd
vendored
Normal file
Binary file not shown.
71
backend-python/utils/midi.py
vendored
71
backend-python/utils/midi.py
vendored
@@ -52,6 +52,8 @@ class VocabConfig:
|
||||
bin_name_to_program_name: Dict[str, str]
|
||||
# Mapping from program number to instrument name.
|
||||
instrument_names: Dict[str, str]
|
||||
# Manual override for velocity bins. Each element is the max velocity value for that bin by index.
|
||||
velocity_bins_override: Optional[List[int]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.validate()
|
||||
@@ -116,6 +118,12 @@ class VocabConfig:
|
||||
raise ValueError("velocity_bins must be at least 2")
|
||||
if len(self.bin_instrument_names) > 16:
|
||||
raise ValueError("bin_instruments must have at most 16 values")
|
||||
if self.velocity_bins_override:
|
||||
print("VocabConfig is using velocity_bins_override. Ignoring velocity_exp.")
|
||||
if len(self.velocity_bins_override) != self.velocity_bins:
|
||||
raise ValueError(
|
||||
"velocity_bins_override must have same length as velocity_bins"
|
||||
)
|
||||
if (
|
||||
self.ch10_instrument_bin_name
|
||||
and self.ch10_instrument_bin_name not in self.bin_instrument_names
|
||||
@@ -156,6 +164,11 @@ class VocabUtils:
|
||||
|
||||
def velocity_to_bin(self, velocity: float) -> int:
|
||||
velocity = max(0, min(velocity, self.cfg.velocity_events - 1))
|
||||
if self.cfg.velocity_bins_override:
|
||||
for i, v in enumerate(self.cfg.velocity_bins_override):
|
||||
if velocity <= v:
|
||||
return i
|
||||
return 0
|
||||
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
|
||||
if self.cfg.velocity_exp == 1.0:
|
||||
return ceil(velocity / binsize)
|
||||
@@ -176,6 +189,8 @@ class VocabUtils:
|
||||
)
|
||||
|
||||
def bin_to_velocity(self, bin: int) -> int:
|
||||
if self.cfg.velocity_bins_override:
|
||||
return self.cfg.velocity_bins_override[bin]
|
||||
binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
|
||||
if self.cfg.velocity_exp == 1.0:
|
||||
return max(0, ceil(bin * binsize - 1))
|
||||
@@ -358,13 +373,32 @@ class AugmentConfig:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterConfig:
|
||||
# Whether to filter out MIDI files with duplicate MD5 hashes.
|
||||
deduplicate_md5: bool
|
||||
# Minimum time delay between notes in a file before splitting into multiple documents.
|
||||
piece_split_delay: float
|
||||
# Minimum length of a piece in milliseconds.
|
||||
min_piece_length: float
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, path: str):
|
||||
with open(path, "r") as f:
|
||||
config = json.load(f)
|
||||
return cls(**config)
|
||||
|
||||
|
||||
def mix_volume(velocity: int, volume: int, expression: int) -> float:
|
||||
return velocity * (volume / 127.0) * (expression / 127.0)
|
||||
|
||||
|
||||
def convert_midi_to_str(
|
||||
cfg: VocabConfig, mid: mido.MidiFile, augment: AugmentValues = None
|
||||
) -> str:
|
||||
cfg: VocabConfig,
|
||||
filter_cfg: FilterConfig,
|
||||
mid: mido.MidiFile,
|
||||
augment: AugmentValues = None,
|
||||
) -> List[str]:
|
||||
utils = VocabUtils(cfg)
|
||||
if augment is None:
|
||||
augment = AugmentValues.default()
|
||||
@@ -390,7 +424,9 @@ def convert_midi_to_str(
|
||||
} # {channel: {(note, program) -> True}}
|
||||
started_flag = False
|
||||
|
||||
output_list = []
|
||||
output = ["<start>"]
|
||||
output_length_ms = 0.0
|
||||
token_data_buffer: List[
|
||||
Tuple[int, int, int, float]
|
||||
] = [] # need to sort notes between wait tokens
|
||||
@@ -432,16 +468,33 @@ def convert_midi_to_str(
|
||||
token_data_buffer = []
|
||||
|
||||
def consume_note_program_data(prog: int, chan: int, note: int, vel: float):
|
||||
nonlocal output, started_flag, delta_time_ms, cfg, utils, token_data_buffer
|
||||
nonlocal output, output_length_ms, started_flag, delta_time_ms, cfg, utils, token_data_buffer
|
||||
is_token_valid = (
|
||||
utils.prog_data_to_token_data(prog, chan, note, vel) is not None
|
||||
)
|
||||
if not is_token_valid:
|
||||
return
|
||||
|
||||
if delta_time_ms > filter_cfg.piece_split_delay * 1000.0:
|
||||
# check if any notes are still held
|
||||
silent = True
|
||||
for channel in channel_notes.keys():
|
||||
if len(channel_notes[channel]) > 0:
|
||||
silent = False
|
||||
break
|
||||
if silent:
|
||||
flush_token_data_buffer()
|
||||
output.append("<end>")
|
||||
if output_length_ms > filter_cfg.min_piece_length * 1000.0:
|
||||
output_list.append(" ".join(output))
|
||||
output = ["<start>"]
|
||||
output_length_ms = 0.0
|
||||
started_flag = False
|
||||
if started_flag:
|
||||
wait_tokens = utils.data_to_wait_tokens(delta_time_ms)
|
||||
if len(wait_tokens) > 0:
|
||||
flush_token_data_buffer()
|
||||
output_length_ms += delta_time_ms
|
||||
output += wait_tokens
|
||||
delta_time_ms = 0.0
|
||||
token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor))
|
||||
@@ -510,7 +563,9 @@ def convert_midi_to_str(
|
||||
|
||||
flush_token_data_buffer()
|
||||
output.append("<end>")
|
||||
return " ".join(output)
|
||||
if output_length_ms > filter_cfg.min_piece_length * 1000.0:
|
||||
output_list.append(" ".join(output))
|
||||
return output_list
|
||||
|
||||
|
||||
def generate_program_change_messages(cfg: VocabConfig):
|
||||
@@ -633,10 +688,10 @@ def token_to_midi_message(
|
||||
if utils.cfg.decode_fix_repeated_notes:
|
||||
if (channel, note) in state.active_notes:
|
||||
del state.active_notes[(channel, note)]
|
||||
yield mido.Message(
|
||||
"note_off", note=note, time=ticks, channel=channel
|
||||
), state
|
||||
ticks = 0
|
||||
yield mido.Message(
|
||||
"note_off", note=note, time=ticks, channel=channel
|
||||
), state
|
||||
ticks = 0
|
||||
state.active_notes[(channel, note)] = state.total_time
|
||||
yield mido.Message(
|
||||
"note_on", note=note, velocity=velocity, time=ticks, channel=channel
|
||||
|
||||
5
backend-python/utils/midi_filter_config.json
Normal file
5
backend-python/utils/midi_filter_config.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"deduplicate_md5": true,
|
||||
"piece_split_delay": 10000,
|
||||
"min_piece_length": 0
|
||||
}
|
||||
@@ -8,15 +8,9 @@ from typing import Dict, Iterable, List, Tuple, Union, Type
|
||||
from utils.log import quick_log
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
from routes import state_cache
|
||||
import global_var
|
||||
|
||||
|
||||
END_OF_TEXT = 0
|
||||
END_OF_LINE_DOUBLE = 535
|
||||
|
||||
|
||||
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
|
||||
|
||||
|
||||
@@ -29,6 +23,8 @@ class RWKVType(Enum):
|
||||
|
||||
class AbstractRWKV(ABC):
|
||||
def __init__(self, model, pipeline):
|
||||
self.EOS_ID = 0
|
||||
|
||||
self.name = "rwkv"
|
||||
self.model = model
|
||||
self.pipeline = pipeline
|
||||
@@ -68,6 +64,8 @@ class AbstractRWKV(ABC):
|
||||
pass
|
||||
|
||||
def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]:
|
||||
import numpy as np
|
||||
|
||||
if fast_mode:
|
||||
embedding, token_len = self.__fast_embedding(
|
||||
self.fix_tokens(self.pipeline.encode(input)), None
|
||||
@@ -222,6 +220,8 @@ class AbstractRWKV(ABC):
|
||||
def generate(
|
||||
self, prompt: str, stop: Union[str, List[str], None] = None
|
||||
) -> Iterable[Tuple[str, str, int, int]]:
|
||||
import numpy as np
|
||||
|
||||
quick_log(None, None, "Generation Prompt:\n" + prompt)
|
||||
cache = None
|
||||
delta_prompt = prompt
|
||||
@@ -231,14 +231,14 @@ class AbstractRWKV(ABC):
|
||||
)
|
||||
except HTTPException:
|
||||
pass
|
||||
if cache is None or cache["prompt"] == "":
|
||||
if cache is None or cache["prompt"] == "" or cache["state"] is None:
|
||||
self.model_state = None
|
||||
self.model_tokens = []
|
||||
else:
|
||||
delta_prompt = prompt[len(cache["prompt"]) :]
|
||||
self.model_state = copy.deepcopy(cache["state"])
|
||||
self.model_tokens = copy.deepcopy(cache["tokens"])
|
||||
logits = copy.deepcopy(cache["logits"])
|
||||
self.model_state = cache["state"]
|
||||
self.model_tokens = cache["tokens"]
|
||||
logits = cache["logits"]
|
||||
|
||||
prompt_token_len = 0
|
||||
if delta_prompt != "":
|
||||
@@ -271,7 +271,7 @@ class AbstractRWKV(ABC):
|
||||
logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k
|
||||
)
|
||||
|
||||
if token == END_OF_TEXT:
|
||||
if token == self.EOS_ID:
|
||||
yield response, "", prompt_token_len, completion_token_len
|
||||
break
|
||||
|
||||
@@ -398,7 +398,7 @@ class TextRWKV(AbstractRWKV):
|
||||
def fix_tokens(self, tokens) -> List[int]:
|
||||
if self.rwkv_type == RWKVType.World:
|
||||
return tokens
|
||||
if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
|
||||
if len(tokens) > 0 and tokens[-1] == 535:
|
||||
tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
|
||||
return tokens
|
||||
|
||||
@@ -456,7 +456,7 @@ The following is a coherent verbose detailed conversation between a girl named {
|
||||
pass
|
||||
|
||||
|
||||
class MusicRWKV(AbstractRWKV):
|
||||
class MusicMidiRWKV(AbstractRWKV):
|
||||
def __init__(self, model, pipeline):
|
||||
super().__init__(model, pipeline)
|
||||
|
||||
@@ -498,8 +498,45 @@ class MusicRWKV(AbstractRWKV):
|
||||
return " " + delta
|
||||
|
||||
|
||||
class MusicAbcRWKV(AbstractRWKV):
|
||||
def __init__(self, model, pipeline):
|
||||
super().__init__(model, pipeline)
|
||||
|
||||
self.EOS_ID = 3
|
||||
|
||||
self.max_tokens_per_generation = 500
|
||||
self.temperature = 1
|
||||
self.top_p = 0.8
|
||||
self.top_k = 8
|
||||
|
||||
self.rwkv_type = RWKVType.Music
|
||||
|
||||
def adjust_occurrence(self, occurrence: Dict, token: int):
|
||||
pass
|
||||
|
||||
def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
|
||||
pass
|
||||
|
||||
def fix_tokens(self, tokens) -> List[int]:
|
||||
return tokens
|
||||
|
||||
def run_rnn(
|
||||
self, _tokens: List[str], newline_adj: int = 0
|
||||
) -> Tuple[List[float], int]:
|
||||
tokens = [int(x) for x in _tokens]
|
||||
token_len = len(tokens)
|
||||
self.model_tokens += tokens
|
||||
out, self.model_state = self.model.forward(tokens, self.model_state)
|
||||
return out, token_len
|
||||
|
||||
def delta_postprocess(self, delta: str) -> str:
|
||||
return delta
|
||||
|
||||
|
||||
def get_tokenizer(tokenizer_len: int):
|
||||
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
|
||||
if tokenizer_len < 20096:
|
||||
return "abc_tokenizer"
|
||||
if tokenizer_len < 50277:
|
||||
return tokenizer_dir + "tokenizer-midi.json"
|
||||
elif tokenizer_len < 65536:
|
||||
@@ -510,12 +547,28 @@ def get_tokenizer(tokenizer_len: int):
|
||||
|
||||
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
|
||||
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
|
||||
rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp")
|
||||
webgpu = global_var.get(global_var.Args).webgpu
|
||||
|
||||
if "midi" in model.lower() or "abc" in model.lower():
|
||||
os.environ["RWKV_RESCALE_LAYER"] = "999"
|
||||
|
||||
# dynamic import to make RWKV_CUDA_ON work
|
||||
if rwkv_beta:
|
||||
print("Using rwkv-beta")
|
||||
from rwkv_pip.beta.model import (
|
||||
RWKV as Model,
|
||||
)
|
||||
elif rwkv_cpp:
|
||||
print("Using rwkv.cpp, strategy is ignored")
|
||||
from rwkv_pip.cpp.model import (
|
||||
RWKV as Model,
|
||||
)
|
||||
elif webgpu:
|
||||
print("Using webgpu")
|
||||
from rwkv_pip.webgpu.model import (
|
||||
RWKV as Model,
|
||||
)
|
||||
else:
|
||||
from rwkv_pip.model import (
|
||||
RWKV as Model,
|
||||
@@ -531,7 +584,8 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
|
||||
rwkv_map: dict[str, Type[AbstractRWKV]] = {
|
||||
"20B_tokenizer": TextRWKV,
|
||||
"rwkv_vocab_v20230424": TextRWKV,
|
||||
"tokenizer-midi": MusicRWKV,
|
||||
"tokenizer-midi": MusicMidiRWKV,
|
||||
"abc_tokenizer": MusicAbcRWKV,
|
||||
}
|
||||
tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
|
||||
rwkv: AbstractRWKV
|
||||
|
||||
5
build/darwin/Readme_Install.txt
vendored
5
build/darwin/Readme_Install.txt
vendored
@@ -1,3 +1,8 @@
|
||||
Client Download URL:
|
||||
客户端下载地址:
|
||||
クライアントのダウンロードURL:
|
||||
https://github.com/josStorer/RWKV-Runner/releases/latest/download/RWKV-Runner_macos_universal.zip
|
||||
|
||||
For Mac and Linux users, please manually install Python 3.10 (usually the latest systems come with it built-in). You can specify the Python interpreter to use in Settings. (which python3)
|
||||
对于Mac和Linux用户,请手动安装 Python3.10 (通常最新的系统已经内置了). 你可以在设置中指定使用的Python解释器. (which python3)
|
||||
MacおよびLinuxのユーザーの方は、Python3.10を手動でインストールしてください(通常、最新のシステムには既に組み込まれています)。 設定メニューで使用するPythonインタプリタを指定することができます。 (which python3)
|
||||
|
||||
5
build/linux/Readme_Install.txt
vendored
5
build/linux/Readme_Install.txt
vendored
@@ -1,3 +1,8 @@
|
||||
Client Download URL:
|
||||
客户端下载地址:
|
||||
クライアントのダウンロードURL:
|
||||
https://github.com/josStorer/RWKV-Runner/releases/latest/download/RWKV-Runner_linux_x64
|
||||
|
||||
For Mac and Linux users, please manually install Python 3.10 (usually the latest systems come with it built-in). You can specify the Python interpreter to use in Settings.
|
||||
对于Mac和Linux用户,请手动安装 Python3.10 (通常最新的系统已经内置了). 你可以在设置中指定使用的Python解释器.
|
||||
MacおよびLinuxのユーザーの方は、Python3.10を手動でインストールしてください(通常、最新のシステムには既に組み込まれています)。 設定メニューで使用するPythonインタプリタを指定することができます。
|
||||
|
||||
5
build/windows/Readme_Install.txt
vendored
5
build/windows/Readme_Install.txt
vendored
@@ -1,3 +1,8 @@
|
||||
Client Download URL:
|
||||
客户端下载地址:
|
||||
クライアントのダウンロードURL:
|
||||
https://github.com/josStorer/RWKV-Runner/releases/latest/download/RWKV-Runner_windows_x64.exe
|
||||
|
||||
Please execute this program in an empty directory. All related dependencies will be placed in this directory.
|
||||
请将本程序放在一个空目录内执行, 所有相关依赖均会放置于此目录.
|
||||
このプログラムを空のディレクトリで実行してください。関連するすべての依存関係は、このディレクトリに配置されます。
|
||||
|
||||
@@ -19,14 +19,15 @@ document.querySelectorAll('.grid.h-10.grid-cols-12.place-content-center.gap-x-3.
|
||||
if (!data.name.endsWith('.bin') && !data.name.endsWith('.pth'))
|
||||
return
|
||||
|
||||
data.desc = {en: '', zh: ''}
|
||||
data.desc = { en: '', zh: '', ja: '' }
|
||||
const rawText = await (await fetch(e.children[1].href.replace('/resolve/', '/raw/'))).text()
|
||||
|
||||
data.size = parseInt(extractValue(rawText, 'size'))
|
||||
data.SHA256 = extractValue(rawText, 'oid sha256:')
|
||||
data.lastUpdated = e.children[3].children[0].getAttribute('datetime')
|
||||
data.url = e.children[1].href.replace('/resolve/', '/blob/')
|
||||
data.downloadUrl = e.children[1].href
|
||||
data.url = e.children[1].href.replace('/resolve/', '/blob/').replace('?download=true', '')
|
||||
data.downloadUrl = e.children[1].href.replace('?download=true', '')
|
||||
data.tags = []
|
||||
|
||||
modelsJson.push(data)
|
||||
})
|
||||
|
||||
@@ -23,6 +23,7 @@ def file_cleaner(file):
|
||||
return cleaner
|
||||
|
||||
|
||||
expected_max_version = float(sys.argv[2]) if len(sys.argv) > 2 else 100
|
||||
model_file = open(sys.argv[1], "rb")
|
||||
cleaner = file_cleaner(model_file)
|
||||
cleaner_thread = threading.Thread(target=cleaner, daemon=True)
|
||||
@@ -31,11 +32,30 @@ cleaner_thread.start()
|
||||
w = torch.load(model_file, map_location="cpu")
|
||||
gc.collect()
|
||||
|
||||
vocab_size = w["emb.weight"].shape[0]
|
||||
n_embd = w["emb.weight"].shape[1]
|
||||
n_layer = 0
|
||||
keys = list(w.keys())
|
||||
version = 4
|
||||
for x in keys:
|
||||
layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
|
||||
n_layer = max(n_layer, layer_id + 1)
|
||||
|
||||
print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
|
||||
if "ln_x" in x:
|
||||
version = max(5, version)
|
||||
if "gate.weight" in x:
|
||||
version = max(5.1, version)
|
||||
if int(version) == 5 and "att.time_decay" in x:
|
||||
if len(w[x].shape) > 1:
|
||||
if w[x].shape[1] > 1:
|
||||
version = max(5.2, version)
|
||||
if "time_maa" in x:
|
||||
version = max(6, version)
|
||||
|
||||
if version <= expected_max_version:
|
||||
print(
|
||||
f"v{int(version)}/train.py --vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}",
|
||||
end="",
|
||||
)
|
||||
else:
|
||||
raise Exception(f"RWKV{version} is not supported")
|
||||
|
||||
@@ -47,10 +47,10 @@ else
|
||||
fi
|
||||
|
||||
echo "loading $loadModel"
|
||||
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel)
|
||||
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 5.2)
|
||||
echo $modelInfo
|
||||
if [[ $modelInfo =~ "--n_layer" ]]; then
|
||||
python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
|
||||
python3 ./finetune/lora/$modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
|
||||
--lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu
|
||||
else
|
||||
echo "modelInfo is invalid"
|
||||
|
||||
@@ -7,6 +7,7 @@ import struct
|
||||
from functools import lru_cache
|
||||
from itertools import accumulate
|
||||
|
||||
|
||||
def print_rank_0(*message):
|
||||
pass
|
||||
# """If distributed is initialized print only on rank 0."""
|
||||
@@ -16,12 +17,14 @@ def print_rank_0(*message):
|
||||
# else:
|
||||
# print(*message, flush=True)
|
||||
|
||||
|
||||
def _warmup_mmap_file(path):
|
||||
pass
|
||||
# with open(path, "rb") as stream:
|
||||
# while stream.read(100 * 1024 * 1024):
|
||||
# pass
|
||||
|
||||
|
||||
dtypes = {
|
||||
1: np.uint8,
|
||||
2: np.int8,
|
||||
@@ -33,18 +36,22 @@ dtypes = {
|
||||
8: np.uint16,
|
||||
}
|
||||
|
||||
|
||||
def code(dtype):
|
||||
for k in dtypes.keys():
|
||||
if dtypes[k] == dtype:
|
||||
return k
|
||||
raise ValueError(dtype)
|
||||
|
||||
|
||||
def index_file_path(prefix_path):
|
||||
return prefix_path + ".idx"
|
||||
|
||||
|
||||
def data_file_path(prefix_path):
|
||||
return prefix_path + ".bin"
|
||||
|
||||
|
||||
class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||
class Index(object):
|
||||
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||
@@ -100,7 +107,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||
self._file.close()
|
||||
|
||||
return _Writer()
|
||||
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
with open(path, "rb") as stream:
|
||||
magic_test = stream.read(9)
|
||||
@@ -217,8 +224,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError(
|
||||
"Slices into indexed_dataset must be contiguous")
|
||||
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||
ptr = self._index._pointers[start]
|
||||
sizes = self._index._sizes[idx]
|
||||
offsets = list(accumulate(sizes))
|
||||
@@ -17,9 +17,11 @@ class MyDataset(Dataset):
|
||||
|
||||
if args.data_type == "binidx":
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
|
||||
rank_zero_info(
|
||||
f"Current vocab size = {self.vocab_size} (make sure it's correct)"
|
||||
)
|
||||
|
||||
if args.data_file.endswith('/'):
|
||||
if args.data_file.endswith("/"):
|
||||
d_all = []
|
||||
for p in os.listdir(args.data_file):
|
||||
if p.endswith(".idx"):
|
||||
@@ -29,33 +31,52 @@ class MyDataset(Dataset):
|
||||
exit(0)
|
||||
else:
|
||||
self.data = MMapIndexedDataset(args.data_file)
|
||||
self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
|
||||
self.data_size = (
|
||||
len(self.data._bin_buffer) // self.data._index._dtype_size
|
||||
)
|
||||
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||
|
||||
if args.my_qa_mask > 0:
|
||||
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
|
||||
self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size
|
||||
self.data_pile = MMapIndexedDataset(
|
||||
"/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document"
|
||||
)
|
||||
self.data_pile_size = (
|
||||
len(self.data_pile._bin_buffer) // self.data._index._dtype_size
|
||||
)
|
||||
|
||||
if args.my_pile_stage > 0:
|
||||
# assert self.data_size == 332115325534 and self.vocab_size == 50277
|
||||
self.samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||
assert self.samples_per_epoch == 40320
|
||||
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
|
||||
rank_zero_info(
|
||||
f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########"
|
||||
)
|
||||
dataset_slot = self.data_size // args.ctx_len
|
||||
if args.my_pile_stage != 4:
|
||||
assert MaybeIsPrime(args.magic_prime)
|
||||
assert args.magic_prime % 3 == 2
|
||||
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
|
||||
assert (
|
||||
args.magic_prime / dataset_slot > 0.99
|
||||
and args.magic_prime / dataset_slot <= 1
|
||||
)
|
||||
elif args.data_type == "numpy":
|
||||
self.data = np.load(args.data_file).astype("int")
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
||||
rank_zero_info(
|
||||
"Current vocab size =", self.vocab_size, "(make sure it's correct)"
|
||||
)
|
||||
self.data_size = len(self.data)
|
||||
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||
elif args.data_type == "uint16":
|
||||
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
|
||||
self.data = (
|
||||
np.fromfile(args.data_file, dtype=np.uint16)
|
||||
.astype("int32")
|
||||
.reshape(-1, args.my_sample_len)
|
||||
)
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
||||
rank_zero_info(
|
||||
"Current vocab size =", self.vocab_size, "(make sure it's correct)"
|
||||
)
|
||||
self.data_size = self.data.shape[0]
|
||||
rank_zero_info(f"Data has {self.data_size} samples.")
|
||||
elif args.data_type == "wds_img":
|
||||
@@ -86,10 +107,14 @@ class MyDataset(Dataset):
|
||||
for u in unique:
|
||||
xxObj[xx] = u
|
||||
xx += 1
|
||||
with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
|
||||
with open(
|
||||
f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le"
|
||||
) as vocab_file:
|
||||
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
|
||||
self.data_size = len(self.data)
|
||||
rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
|
||||
rank_zero_info(
|
||||
f"Data has {self.data_size} tokens, {self.vocab_size} vocab size."
|
||||
)
|
||||
self.stoi = {ch: i for i, ch in enumerate(unique)}
|
||||
self.itos = {i: ch for i, ch in enumerate(unique)}
|
||||
|
||||
@@ -104,36 +129,53 @@ class MyDataset(Dataset):
|
||||
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
|
||||
|
||||
if args.data_type == "wds_img":
|
||||
|
||||
def init_wds(self, bias=0):
|
||||
def identity(x):
|
||||
return x
|
||||
return x
|
||||
|
||||
import webdataset as wds
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
# img_transform = transforms.Compose(
|
||||
# [transforms.CenterCrop(256)]
|
||||
# )
|
||||
img_transform = transforms.Compose([
|
||||
transforms.CenterCrop(512),
|
||||
transforms.Resize((args.my_img_size))
|
||||
])
|
||||
self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
|
||||
img_transform = transforms.Compose(
|
||||
[transforms.CenterCrop(512), transforms.Resize((args.my_img_size))]
|
||||
)
|
||||
self.data_raw = (
|
||||
wds.WebDataset(args.data_file, resampled=True)
|
||||
.shuffle(
|
||||
10000,
|
||||
initial=1000,
|
||||
rng=random.Random(epoch * 100000 + rank + bias * 1e9),
|
||||
)
|
||||
.decode("torchrgb")
|
||||
.to_tuple("jpg", "json", "txt")
|
||||
.map_tuple(img_transform, identity, identity)
|
||||
)
|
||||
for pp in self.data_raw.pipeline:
|
||||
if 'Resampled' in str(pp):
|
||||
if "Resampled" in str(pp):
|
||||
pp.deterministic = True
|
||||
|
||||
def worker_seed():
|
||||
return rank*100000+epoch+bias*1e9
|
||||
return rank * 100000 + epoch + bias * 1e9
|
||||
|
||||
pp.worker_seed = worker_seed
|
||||
self.data = iter(self.data_raw)
|
||||
# print(f"WebDataset loaded for rank {rank} epoch {epoch}")
|
||||
|
||||
if self.data == None:
|
||||
init_wds(self)
|
||||
trial = 0
|
||||
while trial < 10:
|
||||
try:
|
||||
dd = next(self.data) # jpg, json, txt
|
||||
dd = next(self.data) # jpg, json, txt
|
||||
break
|
||||
except:
|
||||
print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
|
||||
print(
|
||||
f"[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]"
|
||||
)
|
||||
self.error_count += 1
|
||||
init_wds(self, self.error_count)
|
||||
trial += 1
|
||||
@@ -144,7 +186,7 @@ class MyDataset(Dataset):
|
||||
return dd[0], dd[2]
|
||||
else:
|
||||
if args.data_type == "uint16":
|
||||
i = np.random.randint(0, self.data_size-1)
|
||||
i = np.random.randint(0, self.data_size - 1)
|
||||
dix = self.data[i]
|
||||
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||
y = torch.tensor(dix[1:], dtype=torch.long)
|
||||
@@ -196,7 +238,12 @@ class MyDataset(Dataset):
|
||||
z_sum = 0
|
||||
isGood = False
|
||||
for i in range(3, ctx_len):
|
||||
if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
|
||||
if (
|
||||
dix[i] == 27
|
||||
and dix[i - 1] == 34
|
||||
and dix[i - 2] == 187
|
||||
and dix[i - 3] == 187
|
||||
):
|
||||
isGood = True
|
||||
if dix[i] == 0:
|
||||
isGood = False
|
||||
@@ -206,7 +253,9 @@ class MyDataset(Dataset):
|
||||
if z_sum == 0:
|
||||
z = [1] * ctx_len
|
||||
i = np.random.randint(0, self.data_pile_size - req_len)
|
||||
dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
|
||||
dix = self.data_pile.get(
|
||||
idx=0, offset=i, length=req_len
|
||||
).astype(int)
|
||||
z = torch.tensor(z, dtype=torch.bfloat16)
|
||||
|
||||
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||
@@ -5,6 +5,7 @@
|
||||
import functools
|
||||
import os, math, gc, importlib
|
||||
import torch
|
||||
|
||||
# torch._C._jit_set_profiling_executor(True)
|
||||
# torch._C._jit_set_profiling_mode(True)
|
||||
import torch.nn as nn
|
||||
@@ -13,7 +14,8 @@ from torch.nn import functional as F
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
from pytorch_lightning.strategies import DeepSpeedStrategy
|
||||
if importlib.util.find_spec('deepspeed'):
|
||||
|
||||
if importlib.util.find_spec("deepspeed"):
|
||||
import deepspeed
|
||||
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
||||
|
||||
@@ -28,9 +30,10 @@ LORA_CONFIG = {
|
||||
|
||||
|
||||
try:
|
||||
print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
|
||||
print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"])
|
||||
except:
|
||||
os.environ["RWKV_MY_TESTING"] = ''
|
||||
os.environ["RWKV_MY_TESTING"] = ""
|
||||
|
||||
|
||||
def __nop(ob):
|
||||
return ob
|
||||
@@ -53,7 +56,26 @@ T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["finetune/lora/cuda/wkv_op_bf16.cpp", "finetune/lora/cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
|
||||
wkv_cuda = load(
|
||||
name=f"wkv_{T_MAX}_bf16",
|
||||
sources=[
|
||||
"finetune/lora/v4/cuda/wkv_op_bf16.cpp",
|
||||
"finetune/lora/v4/cuda/wkv_cuda_bf16.cu",
|
||||
],
|
||||
verbose=True,
|
||||
extra_cuda_cflags=[
|
||||
"-t 4",
|
||||
"-std=c++17",
|
||||
"-res-usage",
|
||||
"--maxrregcount 60",
|
||||
"--use_fast_math",
|
||||
"-O3",
|
||||
"-Xptxas -O3",
|
||||
"--extra-device-vectorization",
|
||||
f"-DTmax={T_MAX}",
|
||||
],
|
||||
)
|
||||
|
||||
class WKV(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, B, T, C, w, u, k, v):
|
||||
@@ -66,10 +88,16 @@ if os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
u = u.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
y = torch.empty(
|
||||
(B, T, C),
|
||||
device=w.device,
|
||||
memory_format=torch.contiguous_format,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
||||
ctx.save_for_backward(w, u, k, v, y)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gy):
|
||||
B = ctx.B
|
||||
@@ -78,16 +106,54 @@ if os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
assert T <= T_MAX
|
||||
assert B * C % min(C, 32) == 0
|
||||
w, u, k, v, y = ctx.saved_tensors
|
||||
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
gw = torch.empty(
|
||||
(B, C),
|
||||
device=gy.device,
|
||||
memory_format=torch.contiguous_format,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
gu = torch.empty(
|
||||
(B, C),
|
||||
device=gy.device,
|
||||
memory_format=torch.contiguous_format,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
gk = torch.empty(
|
||||
(B, T, C),
|
||||
device=gy.device,
|
||||
memory_format=torch.contiguous_format,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
gv = torch.empty(
|
||||
(B, T, C),
|
||||
device=gy.device,
|
||||
memory_format=torch.contiguous_format,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
|
||||
gw = torch.sum(gw, dim=0)
|
||||
gu = torch.sum(gu, dim=0)
|
||||
return (None, None, None, gw, gu, gk, gv)
|
||||
|
||||
else:
|
||||
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["finetune/lora/cuda/wkv_op.cpp", "finetune/lora/cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
|
||||
wkv_cuda = load(
|
||||
name=f"wkv_{T_MAX}",
|
||||
sources=[
|
||||
"finetune/lora/v4/cuda/wkv_op.cpp",
|
||||
"finetune/lora/v4/cuda/wkv_cuda.cu",
|
||||
],
|
||||
verbose=True,
|
||||
extra_cuda_cflags=[
|
||||
"-res-usage",
|
||||
"--maxrregcount 60",
|
||||
"--use_fast_math",
|
||||
"-O3",
|
||||
"-Xptxas -O3",
|
||||
"--extra-device-vectorization",
|
||||
f"-DTmax={T_MAX}",
|
||||
],
|
||||
)
|
||||
|
||||
class WKV(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, B, T, C, w, u, k, v):
|
||||
@@ -106,7 +172,9 @@ else:
|
||||
u = u.float().contiguous()
|
||||
k = k.float().contiguous()
|
||||
v = v.float().contiguous()
|
||||
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
|
||||
y = torch.empty(
|
||||
(B, T, C), device=w.device, memory_format=torch.contiguous_format
|
||||
)
|
||||
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
||||
ctx.save_for_backward(w, u, k, v, y)
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
@@ -115,6 +183,7 @@ else:
|
||||
return y.half()
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
return y.bfloat16()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gy):
|
||||
B = ctx.B
|
||||
@@ -123,14 +192,26 @@ else:
|
||||
assert T <= T_MAX
|
||||
assert B * C % min(C, 32) == 0
|
||||
w, u, k, v, y = ctx.saved_tensors
|
||||
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
gw = torch.empty(
|
||||
(B, C), device=gy.device, memory_format=torch.contiguous_format
|
||||
)
|
||||
gu = torch.empty(
|
||||
(B, C), device=gy.device, memory_format=torch.contiguous_format
|
||||
)
|
||||
gk = torch.empty(
|
||||
(B, T, C), device=gy.device, memory_format=torch.contiguous_format
|
||||
)
|
||||
gv = torch.empty(
|
||||
(B, T, C), device=gy.device, memory_format=torch.contiguous_format
|
||||
)
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
|
||||
wkv_cuda.backward(
|
||||
B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv
|
||||
)
|
||||
else:
|
||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
|
||||
wkv_cuda.backward(
|
||||
B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv
|
||||
)
|
||||
gw = torch.sum(gw, dim=0)
|
||||
gu = torch.sum(gu, dim=0)
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
@@ -138,7 +219,15 @@ else:
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
|
||||
return (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
gw.bfloat16(),
|
||||
gu.bfloat16(),
|
||||
gk.bfloat16(),
|
||||
gv.bfloat16(),
|
||||
)
|
||||
|
||||
|
||||
def RUN_CUDA(B, T, C, w, u, k, v):
|
||||
@@ -151,15 +240,17 @@ def RUN_CUDA(B, T, C, w, u, k, v):
|
||||
|
||||
|
||||
class LoraLinear(nn.Module):
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
|
||||
assert bias == False, "Biased LoraLinear not supported"
|
||||
|
||||
r, alpha, dropout = LORA_CONFIG["r"], LORA_CONFIG[
|
||||
"alpha"], LORA_CONFIG["dropout"]
|
||||
r, alpha, dropout = (
|
||||
LORA_CONFIG["r"],
|
||||
LORA_CONFIG["alpha"],
|
||||
LORA_CONFIG["dropout"],
|
||||
)
|
||||
self.lora_A = nn.Parameter(torch.empty(r, in_features))
|
||||
self.lora_B = nn.Parameter(torch.empty(out_features, r))
|
||||
self.lora_dropout = nn.Dropout(dropout)
|
||||
@@ -170,9 +261,9 @@ class LoraLinear(nn.Module):
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def forward(self, x):
|
||||
return (
|
||||
F.linear(x, self.weight) + self.scaling *
|
||||
F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B))
|
||||
return F.linear(x, self.weight) + self.scaling * F.linear(
|
||||
F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
|
||||
)
|
||||
|
||||
|
||||
@functools.wraps(LoraLinear)
|
||||
@@ -214,17 +305,23 @@ class RWKV_TimeMix(MyModule):
|
||||
# fancy time_decay
|
||||
decay_speed = torch.ones(args.dim_att)
|
||||
for h in range(args.dim_att):
|
||||
decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
|
||||
decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (
|
||||
0.7 + 1.3 * ratio_0_to_1
|
||||
)
|
||||
self.time_decay = nn.Parameter(decay_speed)
|
||||
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
|
||||
|
||||
# fancy time_first
|
||||
zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
|
||||
self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)
|
||||
self.time_first = nn.Parameter(
|
||||
torch.ones(args.dim_att) * math.log(0.3) + zigzag
|
||||
)
|
||||
|
||||
# fancy time_mix
|
||||
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
||||
self.time_mix_v = nn.Parameter(
|
||||
torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
|
||||
)
|
||||
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
||||
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
@@ -235,8 +332,10 @@ class RWKV_TimeMix(MyModule):
|
||||
|
||||
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
|
||||
|
||||
if 'a' in os.environ["RWKV_MY_TESTING"]:
|
||||
self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
||||
if "a" in os.environ["RWKV_MY_TESTING"]:
|
||||
self.register_buffer(
|
||||
"att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||||
)
|
||||
d_qkv = args.n_embd // 16
|
||||
self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
|
||||
self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
|
||||
@@ -245,12 +344,17 @@ class RWKV_TimeMix(MyModule):
|
||||
with torch.no_grad():
|
||||
self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
||||
self.time_mix_vv = nn.Parameter(
|
||||
torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
|
||||
)
|
||||
|
||||
if "a" not in os.environ["RWKV_MY_TESTING"]:
|
||||
|
||||
if 'a' not in os.environ["RWKV_MY_TESTING"]:
|
||||
@MyFunction
|
||||
def jit_func(self, x):
|
||||
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
|
||||
xx = self.time_shift(
|
||||
x
|
||||
) # Mix x with the previous timestep to produce xk, xv, xr
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
@@ -263,21 +367,26 @@ class RWKV_TimeMix(MyModule):
|
||||
def forward(self, x):
|
||||
B, T, C = x.size() # x = (Batch,Time,Channel)
|
||||
sr, k, v = self.jit_func(x)
|
||||
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
|
||||
rwkv = sr * RUN_CUDA(
|
||||
B, T, self.args.dim_att, self.time_decay, self.time_first, k, v
|
||||
)
|
||||
return self.output(rwkv)
|
||||
|
||||
if 'a' in os.environ["RWKV_MY_TESTING"]:
|
||||
if "a" in os.environ["RWKV_MY_TESTING"]:
|
||||
|
||||
@MyFunction
|
||||
def QKV(self, q, k, v):
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
att = att.masked_fill(self.att_mask == 0, float('-inf'))
|
||||
att = F.softmax(att, dim = -1)
|
||||
att = att.masked_fill(self.att_mask == 0, float("-inf"))
|
||||
att = F.softmax(att, dim=-1)
|
||||
x = att @ v
|
||||
return x
|
||||
|
||||
@MyFunction
|
||||
def jit_funcQKV(self, x):
|
||||
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
|
||||
xx = self.time_shift(
|
||||
x
|
||||
) # Mix x with the previous timestep to produce xk, xv, xr
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
@@ -296,12 +405,16 @@ class RWKV_TimeMix(MyModule):
|
||||
def forward(self, x):
|
||||
B, T, C = x.size() # x = (Batch,Time,Channel)
|
||||
sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
|
||||
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
|
||||
rwkv = sr * RUN_CUDA(
|
||||
B, T, self.args.dim_att, self.time_decay, self.time_first, k, v
|
||||
)
|
||||
rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
|
||||
return rwkv
|
||||
|
||||
|
||||
########################################################################################################
|
||||
|
||||
|
||||
class RWKV_ChannelMix(MyModule):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
@@ -331,6 +444,7 @@ class RWKV_ChannelMix(MyModule):
|
||||
kv = self.value(k)
|
||||
return torch.sigmoid(self.receptance(xr)) * kv
|
||||
|
||||
|
||||
class MishGLU(MyModule):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
@@ -360,6 +474,7 @@ class MishGLU(MyModule):
|
||||
b = self.bb(xb)
|
||||
return self.value(a * F.mish(b))
|
||||
|
||||
|
||||
########################################################################################################
|
||||
# The RWKV Model with our blocks
|
||||
########################################################################################################
|
||||
@@ -377,15 +492,19 @@ class Block(nn.Module):
|
||||
if self.layer_id == 0:
|
||||
self.ln0 = nn.LayerNorm(args.n_embd)
|
||||
if args.my_pos_emb > 0:
|
||||
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
|
||||
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
|
||||
self.pos_emb_x = nn.Parameter(
|
||||
torch.zeros((1, args.my_pos_emb, args.n_embd))
|
||||
)
|
||||
self.pos_emb_y = nn.Parameter(
|
||||
torch.zeros((args.my_pos_emb, 1, args.n_embd))
|
||||
)
|
||||
|
||||
if self.layer_id == 0 and self.args.pre_ffn > 0:
|
||||
self.ffnPre = RWKV_ChannelMix(args, 0)
|
||||
else:
|
||||
self.att = RWKV_TimeMix(args, layer_id)
|
||||
|
||||
if 'g' in os.environ["RWKV_MY_TESTING"]:
|
||||
if "g" in os.environ["RWKV_MY_TESTING"]:
|
||||
self.ffn = MishGLU(args, layer_id)
|
||||
else:
|
||||
self.ffn = RWKV_ChannelMix(args, layer_id)
|
||||
@@ -395,7 +514,9 @@ class Block(nn.Module):
|
||||
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
|
||||
self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
||||
self.register_buffer(
|
||||
"tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||||
)
|
||||
|
||||
def forward(self, x, x_emb=None):
|
||||
args = self.args
|
||||
@@ -403,7 +524,7 @@ class Block(nn.Module):
|
||||
if self.layer_id == 0:
|
||||
x = self.ln0(x)
|
||||
if args.my_pos_emb > 0:
|
||||
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
|
||||
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :]
|
||||
x = x + pos_emb
|
||||
|
||||
if self.layer_id == 0 and args.pre_ffn > 0:
|
||||
@@ -443,13 +564,13 @@ class RWKV(pl.LightningModule):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
if not hasattr(args, 'dim_att'):
|
||||
if not hasattr(args, "dim_att"):
|
||||
args.dim_att = args.n_embd
|
||||
if not hasattr(args, 'dim_ffn'):
|
||||
if not hasattr(args, "dim_ffn"):
|
||||
args.dim_ffn = args.n_embd * 4
|
||||
if not hasattr(args, 'tiny_att_layer'):
|
||||
if not hasattr(args, "tiny_att_layer"):
|
||||
args.tiny_att_layer = -1
|
||||
if not hasattr(args, 'tiny_att_dim'):
|
||||
if not hasattr(args, "tiny_att_dim"):
|
||||
args.tiny_att_dim = -1
|
||||
|
||||
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
|
||||
@@ -462,7 +583,9 @@ class RWKV(pl.LightningModule):
|
||||
if args.head_qk > 0:
|
||||
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||
self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
||||
self.register_buffer(
|
||||
"copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||||
)
|
||||
|
||||
def configure_optimizers(self):
|
||||
args = self.args
|
||||
@@ -494,19 +617,46 @@ class RWKV(pl.LightningModule):
|
||||
param_dict = {n: p for n, p in self.named_parameters()}
|
||||
if args.my_pile_stage == 2:
|
||||
optim_groups = [
|
||||
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
||||
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
|
||||
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_1x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 1.0,
|
||||
},
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_2x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 5.0,
|
||||
}, # test: 2e-3 / args.lr_init},
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_3x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 5.0,
|
||||
}, # test: 3e-3 / args.lr_init},
|
||||
]
|
||||
else:
|
||||
optim_groups = [
|
||||
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
||||
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
|
||||
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_1x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 1.0,
|
||||
},
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_2x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 2.0,
|
||||
},
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_3x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 3.0,
|
||||
},
|
||||
]
|
||||
else:
|
||||
optim_groups = [
|
||||
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
|
||||
{
|
||||
"params": [p for n, p in self.named_parameters()],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
for g in optim_groups:
|
||||
@@ -514,8 +664,26 @@ class RWKV(pl.LightningModule):
|
||||
optim_groups = [g for g in optim_groups if len(g["params"]) > 0]
|
||||
|
||||
if self.deepspeed_offload:
|
||||
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
|
||||
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
|
||||
return DeepSpeedCPUAdam(
|
||||
optim_groups,
|
||||
lr=self.args.lr_init,
|
||||
betas=self.args.betas,
|
||||
eps=self.args.adam_eps,
|
||||
bias_correction=True,
|
||||
adamw_mode=False,
|
||||
weight_decay=0,
|
||||
amsgrad=False,
|
||||
)
|
||||
return FusedAdam(
|
||||
optim_groups,
|
||||
lr=self.args.lr_init,
|
||||
betas=self.args.betas,
|
||||
eps=self.args.adam_eps,
|
||||
bias_correction=True,
|
||||
adam_w_mode=False,
|
||||
weight_decay=0,
|
||||
amsgrad=False,
|
||||
)
|
||||
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
|
||||
|
||||
@property
|
||||
@@ -589,10 +757,14 @@ class RWKV(pl.LightningModule):
|
||||
|
||||
logits = self(idx)
|
||||
if sum_mask == mask.shape[0]:
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
loss = F.cross_entropy(
|
||||
logits.view(-1, logits.size(-1)), targets.view(-1)
|
||||
)
|
||||
# print('rank', self.global_rank, 'loss', loss.item())
|
||||
else:
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
|
||||
loss = F.cross_entropy(
|
||||
logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
|
||||
)
|
||||
# loss_raw = loss
|
||||
loss = torch.sum(loss * mask) / sum_mask
|
||||
|
||||
@@ -632,7 +804,14 @@ class RWKV(pl.LightningModule):
|
||||
|
||||
gain = 1.0
|
||||
scale = 1.0
|
||||
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
|
||||
if (
|
||||
"ln_" in n
|
||||
or ".ln" in n
|
||||
or "time_" in n
|
||||
or "_mask" in n
|
||||
or "pos_emb" in n
|
||||
or ".mask." in n
|
||||
):
|
||||
m[n] = p
|
||||
else:
|
||||
if n == "emb.weight":
|
||||
@@ -640,7 +819,19 @@ class RWKV(pl.LightningModule):
|
||||
else:
|
||||
if shape[0] > shape[1]:
|
||||
gain = math.sqrt(shape[0] / shape[1])
|
||||
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
|
||||
for kk in [
|
||||
".att.key.",
|
||||
".att.receptance.",
|
||||
".att.output.",
|
||||
".att.key.",
|
||||
".ffn.value.",
|
||||
".ffn.receptance.",
|
||||
".ffnPre.value.",
|
||||
".ffnPre.receptance.",
|
||||
"head_q.",
|
||||
".oo.",
|
||||
".rr.",
|
||||
]:
|
||||
if kk in n:
|
||||
scale = 0
|
||||
if n == "head.weight":
|
||||
@@ -650,7 +841,9 @@ class RWKV(pl.LightningModule):
|
||||
if "head_q." in n:
|
||||
scale = 0
|
||||
|
||||
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
|
||||
print(
|
||||
f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}"
|
||||
)
|
||||
|
||||
if self.args.accelerator.upper() == "GPU":
|
||||
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
|
||||
@@ -5,15 +5,17 @@ import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
from .model import LORA_CONFIG
|
||||
|
||||
|
||||
def my_save(dd, ff):
|
||||
if '14b-run1' not in ff:
|
||||
if "14b-run1" not in ff:
|
||||
torch.save(dd, ff)
|
||||
else:
|
||||
fn = ff.split('/')[-1]
|
||||
fff = '/dev/shm/' + fn
|
||||
fn = ff.split("/")[-1]
|
||||
fff = "/dev/shm/" + fn
|
||||
torch.save(dd, fff)
|
||||
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
|
||||
|
||||
|
||||
class train_callback(pl.Callback):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
@@ -38,7 +40,9 @@ class train_callback(pl.Callback):
|
||||
if args.lr_final == 0 or args.lr_init == 0: # linear decay
|
||||
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
|
||||
else: # exp decay
|
||||
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
|
||||
lr = args.lr_init * math.exp(
|
||||
math.log(args.lr_final / args.lr_init) * pow(progress, 1)
|
||||
)
|
||||
|
||||
if trainer.global_step < w_step:
|
||||
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
|
||||
@@ -60,7 +64,9 @@ class train_callback(pl.Callback):
|
||||
trainer.my_loss_sum = 0
|
||||
trainer.my_loss_count = 0
|
||||
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
|
||||
trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
|
||||
trainer.my_log.write(
|
||||
f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n"
|
||||
)
|
||||
try:
|
||||
print(f"\n{trainer.strategy.config}\n")
|
||||
trainer.my_log.write(f"{trainer.strategy.config}\n")
|
||||
@@ -70,6 +76,7 @@ class train_callback(pl.Callback):
|
||||
if len(args.wandb) > 0:
|
||||
print("Login to wandb...")
|
||||
import wandb
|
||||
|
||||
wandb.init(
|
||||
project=args.wandb,
|
||||
name=args.run_name + " " + args.my_timestamp,
|
||||
@@ -102,20 +109,26 @@ class train_callback(pl.Callback):
|
||||
# self.log("s", real_step, prog_bar=True, on_step=True)
|
||||
|
||||
if len(args.wandb) > 0:
|
||||
lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
|
||||
lll = {
|
||||
"loss": trainer.my_loss,
|
||||
"lr": trainer.my_lr,
|
||||
"Gtokens": real_step * token_per_step / 1e9,
|
||||
}
|
||||
if kt_s > 0:
|
||||
lll["kt/s"] = kt_s
|
||||
trainer.my_wandb.log(lll, step=int(real_step))
|
||||
if args.magic_prime > 0:
|
||||
expand_factor = 2 if args.my_qa_mask > 0 else 1
|
||||
if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1:
|
||||
if (
|
||||
int(real_step)
|
||||
== int(args.magic_prime * expand_factor // args.real_bsz) - 1
|
||||
):
|
||||
to_save_dict = pl_module.state_dict()
|
||||
my_save(
|
||||
to_save_dict,
|
||||
f"{args.proj_dir}/rwkv-final.pth",
|
||||
)
|
||||
|
||||
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
args = self.args
|
||||
dataset = trainer.train_dataloader.dataset.datasets
|
||||
@@ -128,24 +141,28 @@ class train_callback(pl.Callback):
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
args = self.args
|
||||
if trainer.is_global_zero: # logging & save state_dict
|
||||
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
|
||||
if args.data_type == 'wds_img':
|
||||
if (
|
||||
args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0
|
||||
) or trainer.current_epoch == args.epoch_count - 1:
|
||||
if args.data_type == "wds_img":
|
||||
raw_dict = pl_module.state_dict()
|
||||
to_save_dict = {}
|
||||
for k in raw_dict:
|
||||
if k.startswith('encoder.') or k.startswith('decoder.'):
|
||||
if k.startswith("encoder.") or k.startswith("decoder."):
|
||||
to_save_dict[k] = raw_dict[k]
|
||||
else:
|
||||
to_save_dict = pl_module.state_dict()
|
||||
|
||||
if args.lora:
|
||||
enable_time_finetune = 'time' in LORA_CONFIG["parts"]
|
||||
enable_ln_finetune = 'ln' in LORA_CONFIG["parts"]
|
||||
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
||||
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
||||
lora_dict = {}
|
||||
for name, state in to_save_dict.items():
|
||||
if ('.lora_' in name
|
||||
or (enable_time_finetune and '.time_' in name)
|
||||
or (enable_ln_finetune and '.ln' in name)):
|
||||
if (
|
||||
".lora_" in name
|
||||
or (enable_time_finetune and ".time_" in name)
|
||||
or (enable_ln_finetune and ".ln" in name)
|
||||
):
|
||||
lora_dict[name] = state
|
||||
to_save_dict = lora_dict
|
||||
|
||||
@@ -155,8 +172,10 @@ class train_callback(pl.Callback):
|
||||
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
|
||||
)
|
||||
except Exception as e:
|
||||
print('Error\n\n', e, '\n\n')
|
||||
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
|
||||
print("Error\n\n", e, "\n\n")
|
||||
trainer.my_log.write(
|
||||
f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n"
|
||||
)
|
||||
trainer.my_log.flush()
|
||||
|
||||
trainer.my_loss_sum = 0
|
||||
@@ -178,22 +197,22 @@ def generate_init_weight(model, init_weight_name):
|
||||
mm[k] = src.reshape(mm[k].shape)
|
||||
except:
|
||||
tmp = mm[k].squeeze().clone()
|
||||
print(k, src.shape, '-->', mm[k].shape)
|
||||
print(k, src.shape, "-->", mm[k].shape)
|
||||
ss = src.shape[0]
|
||||
dd = tmp.shape[0]
|
||||
for i in range(dd):
|
||||
pos = i / dd * ss
|
||||
if pos >= ss - 1:
|
||||
tmp[i] = src[ss-1]
|
||||
tmp[i] = src[ss - 1]
|
||||
else:
|
||||
p0 = int(math.floor(pos))
|
||||
ii = pos - p0
|
||||
tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
|
||||
tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii)
|
||||
mm[k] = tmp.reshape(mm[k].shape)
|
||||
sss = src.squeeze().float().cpu().numpy()
|
||||
print(sss[:10], '...', sss[-10:])
|
||||
print(sss[:10], "...", sss[-10:])
|
||||
mmm = mm[k].squeeze().float().cpu().numpy()
|
||||
print(mmm[:10], '...', mmm[-10:])
|
||||
print(mmm[:10], "...", mmm[-10:])
|
||||
|
||||
print(f"Save to {init_weight_name}...")
|
||||
torch.save(mm, init_weight_name)
|
||||
@@ -6,6 +6,7 @@ from torch.nn import functional as F
|
||||
time_slot = {}
|
||||
time_ref = time.time_ns()
|
||||
|
||||
|
||||
def record_time(name):
|
||||
if name not in time_slot:
|
||||
time_slot[name] = 1e20
|
||||
@@ -13,20 +14,23 @@ def record_time(name):
|
||||
if tt < time_slot[name]:
|
||||
time_slot[name] = tt
|
||||
|
||||
class TOKENIZER():
|
||||
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
|
||||
if 'list' in str(type(WORD_NAME)):
|
||||
|
||||
class TOKENIZER:
|
||||
def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"):
|
||||
if "list" in str(type(WORD_NAME)):
|
||||
self.charMode = False
|
||||
if WORD_NAME[0] == WORD_NAME[1]:
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
|
||||
else:
|
||||
from transformers import GPT2TokenizerFast
|
||||
|
||||
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
|
||||
self.vocab_size = len(self.tokenizer)
|
||||
else:
|
||||
self.charMode = True
|
||||
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
|
||||
with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file:
|
||||
self.word_table = json.load(result_file)
|
||||
|
||||
self.vocab_size = len(self.word_table)
|
||||
@@ -37,23 +41,25 @@ class TOKENIZER():
|
||||
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
|
||||
|
||||
def refine_context(self, context):
|
||||
context = context.strip().split('\n')
|
||||
context = context.strip().split("\n")
|
||||
for c in range(len(context)):
|
||||
context[c] = context[c].strip().strip('\u3000').strip('\r')
|
||||
context = list(filter(lambda c: c != '', context))
|
||||
context = '\n' + ('\n'.join(context)).strip()
|
||||
if context == '':
|
||||
context = '\n'
|
||||
context[c] = context[c].strip().strip("\u3000").strip("\r")
|
||||
context = list(filter(lambda c: c != "", context))
|
||||
context = "\n" + ("\n".join(context)).strip()
|
||||
if context == "":
|
||||
context = "\n"
|
||||
return context
|
||||
|
||||
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
|
||||
def sample_logits(
|
||||
self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None
|
||||
):
|
||||
# out[self.UNKNOWN_CHAR] = -float('Inf')
|
||||
lastChar = int(x[-1])
|
||||
|
||||
probs = F.softmax(out, dim=-1)
|
||||
|
||||
if self.charMode:
|
||||
if self.itos[lastChar] == '\n':
|
||||
if self.itos[lastChar] == "\n":
|
||||
top_p = top_p_newline
|
||||
else:
|
||||
top_p = top_p_usual
|
||||
@@ -81,6 +87,7 @@ class TOKENIZER():
|
||||
out = torch.multinomial(probs, num_samples=1)[0]
|
||||
return out
|
||||
|
||||
|
||||
def MaybeIsPrime(number):
|
||||
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
|
||||
return True
|
||||
@@ -121,7 +128,9 @@ def MillerRabinPrimalityTest(number):
|
||||
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
|
||||
iterationNumber = 1
|
||||
|
||||
while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
|
||||
while (iterationNumber <= timesTwoDividNumber - 1) and (
|
||||
randomNumberWithPower != number - 1
|
||||
):
|
||||
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
|
||||
iterationNumber = iterationNumber + 1
|
||||
if randomNumberWithPower != (number - 1):
|
||||
@@ -184,7 +184,7 @@ if __name__ == "__main__":
|
||||
args.num_sanity_val_steps = 0
|
||||
args.check_val_every_n_epoch = int(1e20)
|
||||
args.log_every_n_steps = int(1e20)
|
||||
args.max_epochs = args.epoch_count # continue forever
|
||||
args.max_epochs = args.epoch_count # -1 continue forever
|
||||
args.betas = (args.beta1, args.beta2)
|
||||
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
|
||||
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
|
||||
@@ -373,7 +373,7 @@ if __name__ == "__main__":
|
||||
for param in module.parameters():
|
||||
param.requires_grad = True
|
||||
elif enable_time_finetune and any(
|
||||
n.startswith("time") for n, _ in module.named_parameters()
|
||||
n.startswith("time") for n, _ in module.named_parameters()
|
||||
):
|
||||
for pname, param in module.named_parameters():
|
||||
if pname.startswith("time"):
|
||||
@@ -381,7 +381,7 @@ if __name__ == "__main__":
|
||||
param.requires_grad = True
|
||||
|
||||
if (
|
||||
len(args.load_model) == 0 or args.my_pile_stage == 1
|
||||
len(args.load_model) == 0 or args.my_pile_stage == 1
|
||||
): # shall we build the initial weights?
|
||||
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
|
||||
generate_init_weight(model, init_weight_name) # save initial weights
|
||||
@@ -423,8 +423,8 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
if (
|
||||
args.lr_init > 1e-4
|
||||
or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
|
||||
args.lr_init > 1e-4
|
||||
or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
|
||||
):
|
||||
if "I_KNOW_WHAT_IM_DOING" in os.environ:
|
||||
if trainer.global_rank == 0:
|
||||
@@ -459,10 +459,10 @@ if __name__ == "__main__":
|
||||
|
||||
if "deepspeed" in args.strategy:
|
||||
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
|
||||
args.ds_bucket_mb * 1000 * 1000
|
||||
args.ds_bucket_mb * 1000 * 1000
|
||||
)
|
||||
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
|
||||
args.ds_bucket_mb * 1000 * 1000
|
||||
args.ds_bucket_mb * 1000 * 1000
|
||||
)
|
||||
|
||||
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
||||
202
finetune/lora/v5/cuda/wkv5_cuda.cu
vendored
Normal file
202
finetune/lora/v5/cuda/wkv5_cuda.cu
vendored
Normal file
@@ -0,0 +1,202 @@
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
#include "ATen/ATen.h"
|
||||
typedef at::BFloat16 bf16;
|
||||
|
||||
template <typename F>
|
||||
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
|
||||
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
|
||||
F *__restrict__ const _y)
|
||||
{
|
||||
const int b = blockIdx.x / H;
|
||||
const int h = blockIdx.x % H;
|
||||
const int i = threadIdx.x;
|
||||
_w += h*_N_;
|
||||
_u += h*_N_;
|
||||
|
||||
__shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
|
||||
float state[_N_] = {0};
|
||||
|
||||
__syncthreads();
|
||||
w[i] = _w[i];
|
||||
u[i] = float(_u[i]);
|
||||
__syncthreads();
|
||||
|
||||
for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
|
||||
{
|
||||
__syncthreads();
|
||||
r[i] = float(_r[t]);
|
||||
k[i] = float(_k[t]);
|
||||
__syncthreads();
|
||||
|
||||
const float v = float(_v[t]);
|
||||
float y = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < _N_; j+=4)
|
||||
{
|
||||
const float4& r_ = (float4&)(r[j]);
|
||||
const float4& k_ = (float4&)(k[j]);
|
||||
const float4& w_ = (float4&)(w[j]);
|
||||
const float4& u_ = (float4&)(u[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 x;
|
||||
|
||||
x.x = k_.x * v;
|
||||
x.y = k_.y * v;
|
||||
x.z = k_.z * v;
|
||||
x.w = k_.w * v;
|
||||
|
||||
y += r_.x * (u_.x * x.x + s.x);
|
||||
y += r_.y * (u_.y * x.y + s.y);
|
||||
y += r_.z * (u_.z * x.z + s.z);
|
||||
y += r_.w * (u_.w * x.w + s.w);
|
||||
|
||||
s.x = s.x * w_.x + x.x;
|
||||
s.y = s.y * w_.y + x.y;
|
||||
s.z = s.z * w_.z + x.z;
|
||||
s.w = s.w * w_.w + x.w;
|
||||
}
|
||||
_y[t] = F(y);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
__global__ void kernel_backward(const int B, const int T, const int C, const int H,
|
||||
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
|
||||
F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
|
||||
{
|
||||
const int b = blockIdx.x / H;
|
||||
const int h = blockIdx.x % H;
|
||||
const int i = threadIdx.x;
|
||||
_w += h*_N_;
|
||||
_u += h*_N_;
|
||||
__w += h*_N_;
|
||||
|
||||
__shared__ float w_[_N_], u_[_N_];
|
||||
__shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
|
||||
__syncthreads();
|
||||
w_[i] = _w[i];
|
||||
u_[i] = float(_u[i]);
|
||||
__syncthreads();
|
||||
|
||||
const float w = w_[i];
|
||||
const float ww = __w[i];
|
||||
const float u = u_[i];
|
||||
|
||||
float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
|
||||
|
||||
float gw = 0, gu = 0;
|
||||
const int t000 = b*T*C + h*_N_ + i;
|
||||
const int t111 = (b+1)*T*C + h*_N_ + i;
|
||||
const int t222 = t111 - 2*C;
|
||||
|
||||
for (int t = t000; t < t111; t += C)
|
||||
{
|
||||
__syncthreads();
|
||||
v[i] = float(_v[t]);
|
||||
gy[i] = float(_gy[t]);
|
||||
__syncthreads();
|
||||
|
||||
const float k = float(_k[t]);
|
||||
float gr = 0, gu_ = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < _N_; j++)
|
||||
{
|
||||
float& s = state[j];
|
||||
float x = k * v[j];
|
||||
|
||||
gr += (u * x + s) * gy[j];
|
||||
gu_ += x * gy[j];
|
||||
s = s * w + x;
|
||||
}
|
||||
_gr[t] = F(gr);
|
||||
gu += float(_r[t]) * gu_;
|
||||
}
|
||||
_gu[b*C + h*_N_ + i] = F(gu);
|
||||
|
||||
for (int t = t000; t < t222; t += C)
|
||||
{
|
||||
__syncthreads();
|
||||
v[i] = float(_v[t]);
|
||||
gy[i] = float(_gy[t + 2*C]);
|
||||
__syncthreads();
|
||||
|
||||
const float k = float(_k[t]);
|
||||
float gw_ = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < _N_; j++)
|
||||
{
|
||||
float& s = saaaa[j];
|
||||
float& s2 = sbbbb[j];
|
||||
float x = k * v[j];
|
||||
|
||||
float tmp = w * (x + s);
|
||||
s = tmp;
|
||||
s2 = tmp + w * s2;
|
||||
gw_ += s2 * gy[j];
|
||||
}
|
||||
gw += float(_r[t + 2*C]) * gw_;
|
||||
}
|
||||
_gw[b*C + h*_N_ + i] = F(ww * gw);
|
||||
|
||||
for (int t = t111 - C; t >= t000; t -= C)
|
||||
{
|
||||
__syncthreads();
|
||||
v[i] = float(_v[t]);
|
||||
gy[i] = float(_gy[t]);
|
||||
__syncthreads();
|
||||
|
||||
const float rr = float(_r[t]);
|
||||
float gk = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < _N_; j++)
|
||||
{
|
||||
float& s = scccc[j];
|
||||
float x = rr * gy[j];
|
||||
|
||||
gk += (u * x + s) * v[j];
|
||||
s = x + s * w;
|
||||
}
|
||||
_gk[t] = F(gk);
|
||||
}
|
||||
|
||||
for (int t = t111 - C; t >= t000; t -= C)
|
||||
{
|
||||
__syncthreads();
|
||||
r[i] = float(_r[t]);
|
||||
k[i] = float(_k[t]);
|
||||
__syncthreads();
|
||||
|
||||
const float gyy = float(_gy[t]);
|
||||
float gv = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < _N_; j++)
|
||||
{
|
||||
float& s = sdddd[j];
|
||||
float x = gyy * r[j];
|
||||
|
||||
gv += (u_[j] * x + s) * k[j];
|
||||
s = x + s * w_[j];
|
||||
}
|
||||
_gv[t] = F(gv);
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
|
||||
{
|
||||
assert(H*_N_ == C);
|
||||
assert(_N_%4 == 0);
|
||||
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
|
||||
}
|
||||
|
||||
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
|
||||
{
|
||||
assert(H*_N_ == C);
|
||||
assert(_N_%4 == 0);
|
||||
kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
|
||||
}
|
||||
22
finetune/lora/v5/cuda/wkv5_op.cpp
vendored
Normal file
22
finetune/lora/v5/cuda/wkv5_op.cpp
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
#include <torch/extension.h>
|
||||
#include "ATen/ATen.h"
|
||||
typedef at::BFloat16 bf16;
|
||||
|
||||
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
|
||||
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
|
||||
|
||||
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
|
||||
}
|
||||
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
|
||||
cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
|
||||
}
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &forward, "wkv5 forward");
|
||||
m.def("backward", &backward, "wkv5 backward");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(wkv5, m) {
|
||||
m.def("forward", forward);
|
||||
m.def("backward", backward);
|
||||
}
|
||||
0
finetune/lora/v5/src/__init__.py
vendored
Normal file
0
finetune/lora/v5/src/__init__.py
vendored
Normal file
303
finetune/lora/v5/src/binidx.py
vendored
Normal file
303
finetune/lora/v5/src/binidx.py
vendored
Normal file
@@ -0,0 +1,303 @@
|
||||
from lib2to3.pgen2 import token
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import shutil
|
||||
import struct
|
||||
from functools import lru_cache
|
||||
from itertools import accumulate
|
||||
|
||||
|
||||
def print_rank_0(*message):
|
||||
pass
|
||||
# """If distributed is initialized print only on rank 0."""
|
||||
# if torch.distributed.is_initialized():
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(*message, flush=True)
|
||||
# else:
|
||||
# print(*message, flush=True)
|
||||
|
||||
|
||||
def _warmup_mmap_file(path):
|
||||
pass
|
||||
# with open(path, "rb") as stream:
|
||||
# while stream.read(100 * 1024 * 1024):
|
||||
# pass
|
||||
|
||||
|
||||
dtypes = {
|
||||
1: np.uint8,
|
||||
2: np.int8,
|
||||
3: np.int16,
|
||||
4: np.int32,
|
||||
5: np.int64,
|
||||
6: float,
|
||||
7: np.double,
|
||||
8: np.uint16,
|
||||
}
|
||||
|
||||
|
||||
def code(dtype):
|
||||
for k in dtypes.keys():
|
||||
if dtypes[k] == dtype:
|
||||
return k
|
||||
raise ValueError(dtype)
|
||||
|
||||
|
||||
def index_file_path(prefix_path):
|
||||
return prefix_path + ".idx"
|
||||
|
||||
|
||||
def data_file_path(prefix_path):
|
||||
return prefix_path + ".bin"
|
||||
|
||||
|
||||
class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||
class Index(object):
|
||||
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||
|
||||
@classmethod
|
||||
def writer(cls, path, dtype):
|
||||
class _Writer(object):
|
||||
def __enter__(self):
|
||||
self._file = open(path, "wb")
|
||||
|
||||
# Write Magic string so we can check the file format then opening it again.
|
||||
self._file.write(cls._HDR_MAGIC)
|
||||
# Write version number
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", 1))
|
||||
# Little endian unsigned 8 Bit integer
|
||||
self._file.write(struct.pack("<B", code(dtype)))
|
||||
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _get_pointers(sizes):
|
||||
dtype_size = dtype().itemsize
|
||||
address = 0
|
||||
pointers = []
|
||||
|
||||
for size in sizes:
|
||||
pointers.append(address)
|
||||
address += size * dtype_size
|
||||
|
||||
return pointers
|
||||
|
||||
def write(self, sizes, doc_idx):
|
||||
pointers = self._get_pointers(sizes)
|
||||
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", len(sizes)))
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", len(doc_idx)))
|
||||
|
||||
sizes = np.array(sizes, dtype=np.int32)
|
||||
self._file.write(sizes.tobytes(order="C"))
|
||||
del sizes
|
||||
|
||||
pointers = np.array(pointers, dtype=np.int64)
|
||||
self._file.write(pointers.tobytes(order="C"))
|
||||
del pointers
|
||||
|
||||
doc_idx = np.array(doc_idx, dtype=np.int64)
|
||||
self._file.write(doc_idx.tobytes(order="C"))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._file.close()
|
||||
|
||||
return _Writer()
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
with open(path, "rb") as stream:
|
||||
magic_test = stream.read(9)
|
||||
assert self._HDR_MAGIC == magic_test, (
|
||||
"Index file doesn't match expected format. "
|
||||
"Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
# Little endian unsigned 64 Bit integer
|
||||
version = struct.unpack("<Q", stream.read(8))
|
||||
assert (1,) == version
|
||||
|
||||
# Little endian unsigned 8 Bit integer
|
||||
(dtype_code,) = struct.unpack("<B", stream.read(1))
|
||||
self._dtype = dtypes[dtype_code]
|
||||
self._dtype_size = self._dtype().itemsize
|
||||
|
||||
self._len = struct.unpack("<Q", stream.read(8))[0]
|
||||
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
|
||||
offset = stream.tell()
|
||||
|
||||
if not skip_warmup:
|
||||
print_rank_0(" warming up index mmap file...")
|
||||
_warmup_mmap_file(path)
|
||||
|
||||
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
print_rank_0(" reading sizes...")
|
||||
self._sizes = np.frombuffer(
|
||||
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
|
||||
)
|
||||
print_rank_0(" reading pointers...")
|
||||
self._pointers = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._len,
|
||||
offset=offset + self._sizes.nbytes,
|
||||
)
|
||||
print_rank_0(" reading document index...")
|
||||
self._doc_idx = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._doc_count,
|
||||
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._doc_idx
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def __getitem__(self, i):
|
||||
return self._pointers[i], self._sizes[i]
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
super().__init__()
|
||||
|
||||
self._path = None
|
||||
self._index = None
|
||||
self._bin_buffer = None
|
||||
|
||||
self._do_init(path, skip_warmup)
|
||||
|
||||
def __getstate__(self):
|
||||
return self._path
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._do_init(state)
|
||||
|
||||
def _do_init(self, path, skip_warmup):
|
||||
self._path = path
|
||||
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
||||
|
||||
if not skip_warmup:
|
||||
print_rank_0(" warming up data mmap file...")
|
||||
_warmup_mmap_file(data_file_path(self._path))
|
||||
print_rank_0(" creating numpy buffer of mmap...")
|
||||
self._bin_buffer_mmap = np.memmap(
|
||||
data_file_path(self._path), mode="r", order="C"
|
||||
)
|
||||
print_rank_0(" creating memory view of numpy buffer...")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
del self._index
|
||||
|
||||
def __len__(self):
|
||||
return len(self._index)
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
ptr, size = self._index[idx]
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||
)
|
||||
return np_array
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||
ptr = self._index._pointers[start]
|
||||
sizes = self._index._sizes[idx]
|
||||
offsets = list(accumulate(sizes))
|
||||
total_size = sum(sizes)
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
|
||||
)
|
||||
sents = np.split(np_array, offsets[:-1])
|
||||
return sents
|
||||
|
||||
def get(self, idx, offset=0, length=None):
|
||||
"""Retrieves a single item from the dataset with the option to only
|
||||
return a portion of the item.
|
||||
|
||||
get(idx) is the same as [idx] but get() does not support slicing.
|
||||
"""
|
||||
ptr, size = self._index[idx]
|
||||
if length is None:
|
||||
length = size - offset
|
||||
ptr += offset * np.dtype(self._index.dtype).itemsize
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
||||
)
|
||||
return np_array
|
||||
|
||||
def pad(self, idx, length=None):
|
||||
ptr, size = self._index[idx]
|
||||
try:
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
||||
)
|
||||
except:
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||
)
|
||||
ptr0, _ = self._index[0]
|
||||
np_array0 = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=self._index.dtype,
|
||||
count=length - size,
|
||||
offset=ptr0,
|
||||
)
|
||||
np_array = np.append(np_array, np_array0)
|
||||
return np_array
|
||||
|
||||
def only(self, idx):
|
||||
ptr, size = self._index[idx]
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||
)
|
||||
|
||||
return np_array
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._index.sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._index.doc_idx
|
||||
|
||||
def get_doc_idx(self):
|
||||
return self._index._doc_idx
|
||||
|
||||
def set_doc_idx(self, doc_idx_):
|
||||
self._index._doc_idx = doc_idx_
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return os.path.exists(index_file_path(path)) and os.path.exists(
|
||||
data_file_path(path)
|
||||
)
|
||||
241
finetune/lora/v5/src/dataset.py
vendored
Normal file
241
finetune/lora/v5/src/dataset.py
vendored
Normal file
@@ -0,0 +1,241 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
|
||||
import json, math, random, os, sys
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
from .binidx import MMapIndexedDataset
|
||||
from .utils import MaybeIsPrime
|
||||
|
||||
|
||||
class MyDataset(Dataset):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
if args.data_type == "binidx":
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info(
|
||||
f"Current vocab size = {self.vocab_size} (make sure it's correct)"
|
||||
)
|
||||
|
||||
if args.my_pile_version == 1:
|
||||
self.data = MMapIndexedDataset(args.data_file)
|
||||
self.data_size = (
|
||||
len(self.data._bin_buffer) // self.data._index._dtype_size
|
||||
)
|
||||
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||
elif args.my_pile_version == 2:
|
||||
data_list = (
|
||||
open(args.data_file, "r", encoding="utf-8")
|
||||
.read()
|
||||
.strip()
|
||||
.split("\n")
|
||||
)
|
||||
data_list = [i.strip().split(" ") for i in data_list]
|
||||
self.data = []
|
||||
self.data_size = int(data_list[-1][-1])
|
||||
rank_zero_info(f"Data has {self.data_size} chunks.")
|
||||
for d in data_list:
|
||||
data = MMapIndexedDataset(d[0])
|
||||
data_size = len(data._bin_buffer) // data._index._dtype_size
|
||||
assert (data_size - args.ctx_len) == int(d[1])
|
||||
self.data += [[int(d[-1]), int(d[1]), data]]
|
||||
# rank_zero_info(self.data)
|
||||
|
||||
if args.my_qa_mask > 0:
|
||||
# self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
|
||||
self.data_pile = MMapIndexedDataset(
|
||||
"/fsx/pile_deduped/pile_0.87_deduped_text_document"
|
||||
)
|
||||
self.data_pile_size = (
|
||||
len(self.data_pile._bin_buffer) // self.data._index._dtype_size
|
||||
)
|
||||
else:
|
||||
self.data_pile = None
|
||||
self.data_pile_size = 0
|
||||
|
||||
if args.my_pile_stage > 0:
|
||||
# assert self.data_size == 332115325534 and self.vocab_size == 50277
|
||||
self.samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||
assert self.samples_per_epoch == 40320
|
||||
rank_zero_info(
|
||||
f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########"
|
||||
)
|
||||
dataset_slot = self.data_size // args.ctx_len
|
||||
if args.my_pile_stage != 4:
|
||||
assert MaybeIsPrime(args.magic_prime)
|
||||
assert args.magic_prime % 3 == 2
|
||||
assert (
|
||||
args.magic_prime / dataset_slot > 0.99
|
||||
and args.magic_prime / dataset_slot <= 1
|
||||
)
|
||||
elif args.data_type == "numpy":
|
||||
self.data = np.load(args.data_file).astype("int")
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info(
|
||||
f"Current vocab size = {self.vocab_size} (make sure it's correct)"
|
||||
)
|
||||
self.data_size = len(self.data)
|
||||
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||
elif args.data_type == "uint16":
|
||||
self.data = (
|
||||
np.fromfile(args.data_file, dtype=np.uint16)
|
||||
.astype("int32")
|
||||
.reshape(-1, args.my_sample_len)
|
||||
)
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info(
|
||||
f"Current vocab size = {self.vocab_size} (make sure it's correct)"
|
||||
)
|
||||
self.data_size = self.data.shape[0]
|
||||
rank_zero_info(f"Data has {self.data_size} samples.")
|
||||
else:
|
||||
if args.data_type == "dummy":
|
||||
rank_zero_info("Building dummy data...")
|
||||
self.data = ""
|
||||
for i in range(100000):
|
||||
aa = (i) % 10000
|
||||
bb = (i * i) % 10000
|
||||
cc = aa + bb
|
||||
self.data += f".{aa}+{bb}={cc}."
|
||||
else:
|
||||
self.data = open(args.data_file, "r", encoding=args.data_type).read()
|
||||
rank_zero_info("Building token list...")
|
||||
unique = sorted(list(set(self.data)))
|
||||
self.vocab_size = len(unique)
|
||||
# rank_zero_info()
|
||||
# for u in unique:
|
||||
# print(u, end=' ')
|
||||
# rank_zero_info('\n\n')
|
||||
xx = 0
|
||||
xxObj = {}
|
||||
for u in unique:
|
||||
xxObj[xx] = u
|
||||
xx += 1
|
||||
with open(
|
||||
f"{args.proj_dir}/vocab.json", "w", encoding="utf-8"
|
||||
) as vocab_file:
|
||||
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
|
||||
self.data_size = len(self.data)
|
||||
rank_zero_info(
|
||||
f"Data has {self.data_size} tokens, {self.vocab_size} vocab size."
|
||||
)
|
||||
self.stoi = {ch: i for i, ch in enumerate(unique)}
|
||||
self.itos = {i: ch for i, ch in enumerate(unique)}
|
||||
|
||||
def __len__(self):
|
||||
return self.args.epoch_steps * self.args.micro_bsz
|
||||
|
||||
def __getitem__(self, idx):
|
||||
args = self.args
|
||||
rank = self.global_rank
|
||||
epoch = self.real_epoch
|
||||
world_size = self.world_size
|
||||
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
|
||||
|
||||
if args.data_type == "uint16":
|
||||
i = np.random.randint(0, self.data_size - 1)
|
||||
dix = self.data[i]
|
||||
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||
y = torch.tensor(dix[1:], dtype=torch.long)
|
||||
else:
|
||||
ctx_len = args.ctx_len
|
||||
req_len = ctx_len + 1
|
||||
magic_prime = args.magic_prime
|
||||
data = self.data
|
||||
|
||||
if args.my_pile_stage > 0:
|
||||
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
|
||||
|
||||
if args.my_qa_mask > 0:
|
||||
ii_orig = ii
|
||||
if ii % 2 == 0:
|
||||
ii = -1
|
||||
data = self.data_pile
|
||||
else:
|
||||
ii = ii // 2
|
||||
if data == self.data_pile:
|
||||
i = np.random.randint(0, self.data_pile_size - req_len)
|
||||
else:
|
||||
if args.my_pile_stage == 4 or ii < args.my_random_steps:
|
||||
# cheat: pick a random spot in dataset
|
||||
if args.my_pile_version == 1:
|
||||
i = np.random.randint(0, self.data_size - req_len)
|
||||
else:
|
||||
i = np.random.randint(0, self.data_size)
|
||||
else:
|
||||
ii = ii - args.my_random_steps
|
||||
factor = (math.sqrt(5) - 1) / 2
|
||||
factor = int(magic_prime * factor)
|
||||
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
|
||||
i = i + args.my_pile_shift
|
||||
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
|
||||
else:
|
||||
# cheat: pick a random spot in dataset
|
||||
i = np.random.randint(0, self.data_size - req_len)
|
||||
|
||||
if args.data_type == "binidx":
|
||||
if args.my_pile_version == 1:
|
||||
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
|
||||
else:
|
||||
# self.data : cutoff, chunk_count, data
|
||||
for j in range(len(data)):
|
||||
if i < data[j][0]:
|
||||
ii = i
|
||||
i = (i - (data[j - 1][0] if j > 0 else 0)) % data[j][1]
|
||||
dix = (
|
||||
data[j][2]
|
||||
.get(idx=0, offset=i, length=req_len)
|
||||
.astype(int)
|
||||
)
|
||||
# print(ii, j, i)
|
||||
break
|
||||
elif args.data_type == "numpy":
|
||||
dix = data[i : i + req_len]
|
||||
else:
|
||||
dix = [self.stoi[s] for s in data[i : i + req_len]]
|
||||
|
||||
if args.my_qa_mask == 1:
|
||||
if data == self.data_pile:
|
||||
z = [1] * ctx_len
|
||||
else:
|
||||
z = [0] * ctx_len
|
||||
z_sum = 0
|
||||
isGood = False
|
||||
for i in range(3, ctx_len):
|
||||
if (
|
||||
dix[i] == 27
|
||||
and dix[i - 1] == 34
|
||||
and dix[i - 2] == 187
|
||||
and dix[i - 3] == 187
|
||||
):
|
||||
isGood = True
|
||||
if dix[i] == 0:
|
||||
isGood = False
|
||||
if isGood:
|
||||
z[i] = 1
|
||||
z_sum += 1
|
||||
if z_sum == 0:
|
||||
z = [1] * ctx_len
|
||||
i = np.random.randint(0, self.data_pile_size - req_len)
|
||||
dix = self.data_pile.get(
|
||||
idx=0, offset=i, length=req_len
|
||||
).astype(int)
|
||||
z = torch.tensor(z, dtype=torch.bfloat16)
|
||||
|
||||
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||
y = torch.tensor(dix[1:], dtype=torch.long)
|
||||
|
||||
# if ii_orig < 50:
|
||||
# # if rank == 1:
|
||||
# print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
|
||||
# else:
|
||||
# exit(0)
|
||||
|
||||
if args.my_qa_mask == 1:
|
||||
return x, y, z
|
||||
|
||||
return x, y
|
||||
819
finetune/lora/v5/src/model.py
vendored
Normal file
819
finetune/lora/v5/src/model.py
vendored
Normal file
@@ -0,0 +1,819 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
import functools
|
||||
import os, math, gc, importlib
|
||||
import torch
|
||||
|
||||
# torch._C._jit_set_profiling_executor(True)
|
||||
# torch._C._jit_set_profiling_mode(True)
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||
from torch.nn import functional as F
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
from pytorch_lightning.strategies import DeepSpeedStrategy
|
||||
|
||||
if importlib.util.find_spec("deepspeed"):
|
||||
import deepspeed
|
||||
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
||||
|
||||
|
||||
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
|
||||
|
||||
# lora-config
|
||||
LORA_CONFIG = {
|
||||
"r": 0,
|
||||
"alpha": 0,
|
||||
"dropout": 0,
|
||||
"parts": {"att", "ln", "time"},
|
||||
}
|
||||
|
||||
try:
|
||||
print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"])
|
||||
except:
|
||||
os.environ["RWKV_MY_TESTING"] = ""
|
||||
|
||||
|
||||
def __nop(ob):
|
||||
return ob
|
||||
|
||||
|
||||
MyModule = nn.Module
|
||||
MyFunction = __nop
|
||||
if os.environ["RWKV_JIT_ON"] == "1":
|
||||
MyModule = torch.jit.ScriptModule
|
||||
MyFunction = torch.jit.script_method
|
||||
|
||||
|
||||
########################################################################################################
|
||||
# CUDA Kernel
|
||||
########################################################################################################
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"])
|
||||
wkv5_cuda = load(
|
||||
name="wkv5",
|
||||
sources=[
|
||||
"finetune/lora/v5/cuda/wkv5_op.cpp",
|
||||
f"finetune/lora/v5/cuda/wkv5_cuda.cu",
|
||||
],
|
||||
verbose=True,
|
||||
extra_cuda_cflags=[
|
||||
"-res-usage",
|
||||
"--use_fast_math",
|
||||
"-O3",
|
||||
"-Xptxas -O3",
|
||||
"--extra-device-vectorization",
|
||||
f"-D_N_={HEAD_SIZE}",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class WKV_5(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, B, T, C, H, r, k, v, w, u):
|
||||
with torch.no_grad():
|
||||
assert r.dtype == torch.bfloat16
|
||||
assert k.dtype == torch.bfloat16
|
||||
assert v.dtype == torch.bfloat16
|
||||
assert w.dtype == torch.bfloat16
|
||||
assert u.dtype == torch.bfloat16
|
||||
assert HEAD_SIZE == C // H
|
||||
ctx.B = B
|
||||
ctx.T = T
|
||||
ctx.C = C
|
||||
ctx.H = H
|
||||
assert r.is_contiguous()
|
||||
assert k.is_contiguous()
|
||||
assert v.is_contiguous()
|
||||
assert w.is_contiguous()
|
||||
assert u.is_contiguous()
|
||||
ew = (-torch.exp(w.float())).contiguous()
|
||||
eew = (torch.exp(ew)).contiguous()
|
||||
ctx.save_for_backward(r, k, v, eew, ew, u)
|
||||
y = torch.empty(
|
||||
(B, T, C),
|
||||
device=r.device,
|
||||
dtype=torch.bfloat16,
|
||||
memory_format=torch.contiguous_format,
|
||||
) # .uniform_(-1, 1)
|
||||
wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gy):
|
||||
with torch.no_grad():
|
||||
assert gy.dtype == torch.bfloat16
|
||||
B = ctx.B
|
||||
T = ctx.T
|
||||
C = ctx.C
|
||||
H = ctx.H
|
||||
assert gy.is_contiguous()
|
||||
r, k, v, eew, ew, u = ctx.saved_tensors
|
||||
gr = torch.empty(
|
||||
(B, T, C),
|
||||
device=gy.device,
|
||||
requires_grad=False,
|
||||
dtype=torch.bfloat16,
|
||||
memory_format=torch.contiguous_format,
|
||||
) # .uniform_(-1, 1)
|
||||
gk = torch.empty(
|
||||
(B, T, C),
|
||||
device=gy.device,
|
||||
requires_grad=False,
|
||||
dtype=torch.bfloat16,
|
||||
memory_format=torch.contiguous_format,
|
||||
) # .uniform_(-1, 1)
|
||||
gv = torch.empty(
|
||||
(B, T, C),
|
||||
device=gy.device,
|
||||
requires_grad=False,
|
||||
dtype=torch.bfloat16,
|
||||
memory_format=torch.contiguous_format,
|
||||
) # .uniform_(-1, 1)
|
||||
gw = torch.empty(
|
||||
(B, C),
|
||||
device=gy.device,
|
||||
requires_grad=False,
|
||||
dtype=torch.bfloat16,
|
||||
memory_format=torch.contiguous_format,
|
||||
) # .uniform_(-1, 1)
|
||||
gu = torch.empty(
|
||||
(B, C),
|
||||
device=gy.device,
|
||||
requires_grad=False,
|
||||
dtype=torch.bfloat16,
|
||||
memory_format=torch.contiguous_format,
|
||||
) # .uniform_(-1, 1)
|
||||
wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
|
||||
gw = torch.sum(gw, 0).view(H, C // H)
|
||||
gu = torch.sum(gu, 0).view(H, C // H)
|
||||
return (None, None, None, None, gr, gk, gv, gw, gu)
|
||||
|
||||
|
||||
def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
|
||||
return WKV_5.apply(B, T, C, H, r, k, v, w, u)
|
||||
|
||||
|
||||
#################################################################
|
||||
class LoraLinear(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
|
||||
assert bias == False, "Biased LoraLinear not supported"
|
||||
|
||||
r, alpha, dropout = (
|
||||
LORA_CONFIG["r"],
|
||||
LORA_CONFIG["alpha"],
|
||||
LORA_CONFIG["dropout"],
|
||||
)
|
||||
self.lora_A = nn.Parameter(torch.empty(r, in_features))
|
||||
self.lora_B = nn.Parameter(torch.empty(out_features, r))
|
||||
self.lora_dropout = nn.Dropout(dropout)
|
||||
self.scaling = alpha / r
|
||||
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def forward(self, x):
|
||||
return F.linear(x, self.weight) + self.scaling * F.linear(
|
||||
F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
|
||||
)
|
||||
|
||||
|
||||
@functools.wraps(LoraLinear)
|
||||
def make_linear_att(*args, **kwargs):
|
||||
if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
|
||||
return LoraLinear(*args, **kwargs)
|
||||
else:
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
@functools.wraps(LoraLinear)
|
||||
def make_linear_ffn(*args, **kwargs):
|
||||
if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
|
||||
return LoraLinear(*args, **kwargs)
|
||||
else:
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
########################################################################################################
|
||||
|
||||
|
||||
class RWKV_TimeMix_RWKV5(MyModule):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.head_size = args.head_size_a
|
||||
assert HEAD_SIZE == self.head_size # change HEAD_SIZE to match args.head_size_a
|
||||
self.n_head = args.dim_att // self.head_size
|
||||
assert args.dim_att % self.n_head == 0
|
||||
self.head_size_divisor = args.head_size_divisor
|
||||
|
||||
with torch.no_grad():
|
||||
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
|
||||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||||
ddd = torch.ones(1, 1, args.n_embd)
|
||||
for i in range(args.n_embd):
|
||||
ddd[0, 0, i] = i / args.n_embd
|
||||
|
||||
# fancy time_mix
|
||||
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_v = nn.Parameter(
|
||||
torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
|
||||
)
|
||||
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
||||
self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
||||
|
||||
# fancy time_decay
|
||||
decay_speed = torch.ones(args.dim_att)
|
||||
for n in range(args.dim_att):
|
||||
decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
|
||||
0.7 + 1.3 * ratio_0_to_1
|
||||
)
|
||||
self.time_decay = nn.Parameter(
|
||||
decay_speed.reshape(self.n_head, self.head_size)
|
||||
)
|
||||
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
|
||||
|
||||
tmp = torch.zeros(args.dim_att)
|
||||
for n in range(args.dim_att):
|
||||
zigzag = ((n + 1) % 3 - 1) * 0.1
|
||||
tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
|
||||
|
||||
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
|
||||
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
|
||||
self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||
self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||
self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||
|
||||
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
|
||||
self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||
self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)
|
||||
|
||||
@MyFunction
|
||||
def jit_func(self, x):
|
||||
B, T, C = x.size()
|
||||
|
||||
xx = self.time_shift(
|
||||
x
|
||||
) # Mix x with the previous timestep to produce xk, xv, xr
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)
|
||||
|
||||
r = self.receptance(xr)
|
||||
k = self.key(xk)
|
||||
v = self.value(xv)
|
||||
g = F.silu(self.gate(xg))
|
||||
|
||||
return r, k, v, g
|
||||
|
||||
@MyFunction
|
||||
def jit_func_2(self, x, g):
|
||||
B, T, C = x.size()
|
||||
x = x.view(B * T, C)
|
||||
x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
|
||||
x = self.output(x * g)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.size()
|
||||
H = self.n_head
|
||||
r, k, v, g = self.jit_func(x)
|
||||
x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)
|
||||
|
||||
return self.jit_func_2(x, g)
|
||||
|
||||
|
||||
########################################################################################################
|
||||
|
||||
|
||||
class RWKV_ChannelMix(MyModule):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
|
||||
with torch.no_grad(): # fancy init of time_mix
|
||||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||||
ddd = torch.ones(1, 1, args.n_embd)
|
||||
for i in range(args.n_embd):
|
||||
ddd[0, 0, i] = i / args.n_embd
|
||||
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
|
||||
self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
|
||||
self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
|
||||
self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)
|
||||
|
||||
@MyFunction
|
||||
def forward(self, x):
|
||||
xx = self.time_shift(x)
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
k = self.key(xk)
|
||||
k = torch.relu(k) ** 2
|
||||
kv = self.value(k)
|
||||
return torch.sigmoid(self.receptance(xr)) * kv
|
||||
|
||||
|
||||
class MishGLU(MyModule):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
|
||||
with torch.no_grad():
|
||||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
|
||||
|
||||
x = torch.ones(1, 1, args.n_embd)
|
||||
for i in range(args.n_embd):
|
||||
x[0, 0, i] = i / args.n_embd
|
||||
|
||||
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
||||
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
||||
self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
||||
self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
||||
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
|
||||
|
||||
@MyFunction
|
||||
def forward(self, x):
|
||||
xx = self.time_shift(x)
|
||||
xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
a = self.aa(xa)
|
||||
b = self.bb(xb)
|
||||
return self.value(a * F.mish(b))
|
||||
|
||||
|
||||
########################################################################################################
|
||||
# The RWKV Model with our blocks
|
||||
########################################################################################################
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.ln1 = nn.LayerNorm(args.n_embd)
|
||||
self.ln2 = nn.LayerNorm(args.n_embd)
|
||||
|
||||
if self.layer_id == 0:
|
||||
self.ln0 = nn.LayerNorm(args.n_embd)
|
||||
if args.my_pos_emb > 0:
|
||||
self.pos_emb_x = nn.Parameter(
|
||||
torch.zeros((1, args.my_pos_emb, args.n_embd))
|
||||
)
|
||||
self.pos_emb_y = nn.Parameter(
|
||||
torch.zeros((args.my_pos_emb, 1, args.n_embd))
|
||||
)
|
||||
|
||||
if self.layer_id == 0 and self.args.pre_ffn > 0:
|
||||
self.ffnPre = RWKV_ChannelMix(args, 0)
|
||||
else:
|
||||
self.att = RWKV_TimeMix_RWKV5(args, layer_id)
|
||||
|
||||
if "g" in os.environ["RWKV_MY_TESTING"]:
|
||||
self.ffn = MishGLU(args, layer_id)
|
||||
else:
|
||||
self.ffn = RWKV_ChannelMix(args, layer_id)
|
||||
|
||||
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
||||
self.tiny_ln = nn.LayerNorm(args.n_embd)
|
||||
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
|
||||
self.register_buffer(
|
||||
"tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||||
)
|
||||
|
||||
if args.dropout > 0:
|
||||
self.drop0 = nn.Dropout(p=args.dropout)
|
||||
self.drop1 = nn.Dropout(p=args.dropout)
|
||||
|
||||
def forward(self, x, x_emb=None):
|
||||
args = self.args
|
||||
B, T, C = x.size()
|
||||
if self.layer_id == 0:
|
||||
x = self.ln0(x)
|
||||
if args.my_pos_emb > 0:
|
||||
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :]
|
||||
x = x + pos_emb
|
||||
|
||||
if self.args.dropout == 0:
|
||||
if self.layer_id == 0 and args.pre_ffn > 0:
|
||||
x = x + self.ffnPre(self.ln1(x))
|
||||
else:
|
||||
x = x + self.att(self.ln1(x))
|
||||
x = x + self.ffn(self.ln2(x))
|
||||
else:
|
||||
if self.layer_id == 0 and args.pre_ffn > 0:
|
||||
x = self.drop0(x + self.ffnPre(self.ln1(x)))
|
||||
else:
|
||||
x = self.drop0(x + self.att(self.ln1(x)))
|
||||
x = self.drop1(x + self.ffn(self.ln2(x)))
|
||||
|
||||
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
||||
xx = self.tiny_ln(x)
|
||||
q = self.tiny_q(xx)[:, :T, :]
|
||||
k = self.tiny_k(xx)[:, :T, :]
|
||||
c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
|
||||
c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
|
||||
x = x + c @ self.tiny_v(x_emb)
|
||||
return x
|
||||
|
||||
|
||||
class L2Wrap(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, loss, y):
|
||||
ctx.save_for_backward(y)
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
y = ctx.saved_tensors[0]
|
||||
# to encourage the logits to be close to 0
|
||||
factor = 1e-4 / (y.shape[0] * y.shape[1])
|
||||
maxx, ids = torch.max(y, -1, keepdim=True)
|
||||
gy = torch.zeros_like(y)
|
||||
gy.scatter_(-1, ids, maxx * factor)
|
||||
return (grad_output, gy)
|
||||
|
||||
|
||||
class RWKV(pl.LightningModule):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
if not hasattr(args, "dim_att"):
|
||||
args.dim_att = args.n_embd
|
||||
if not hasattr(args, "dim_ffn"):
|
||||
args.dim_ffn = args.n_embd * 4
|
||||
if not hasattr(args, "tiny_att_layer"):
|
||||
args.tiny_att_layer = -1
|
||||
if not hasattr(args, "tiny_att_dim"):
|
||||
args.tiny_att_dim = -1
|
||||
assert args.n_embd % 32 == 0
|
||||
assert args.dim_att % 32 == 0
|
||||
assert args.dim_ffn % 32 == 0
|
||||
|
||||
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
|
||||
|
||||
self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
|
||||
|
||||
self.ln_out = nn.LayerNorm(args.n_embd)
|
||||
self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
|
||||
|
||||
if args.head_qk > 0:
|
||||
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||
self.register_buffer(
|
||||
"copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||||
)
|
||||
if args.dropout > 0:
|
||||
self.drop0 = nn.Dropout(p=args.dropout)
|
||||
|
||||
def configure_optimizers(self):
|
||||
args = self.args
|
||||
|
||||
lr_decay = set()
|
||||
lr_1x = set()
|
||||
lr_2x = set()
|
||||
lr_3x = set()
|
||||
for n, p in self.named_parameters():
|
||||
if ("time_mix" in n) and (args.layerwise_lr > 0):
|
||||
if args.my_pile_stage == 2:
|
||||
lr_2x.add(n)
|
||||
else:
|
||||
lr_1x.add(n)
|
||||
elif ("time_decay" in n) and (args.layerwise_lr > 0):
|
||||
if args.my_pile_stage == 2:
|
||||
lr_3x.add(n)
|
||||
else:
|
||||
lr_2x.add(n)
|
||||
elif ("time_faaaa" in n) and (args.layerwise_lr > 0):
|
||||
if args.my_pile_stage == 2:
|
||||
lr_2x.add(n)
|
||||
else:
|
||||
lr_1x.add(n)
|
||||
elif ("time_first" in n) and (args.layerwise_lr > 0):
|
||||
lr_3x.add(n)
|
||||
elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0):
|
||||
lr_decay.add(n)
|
||||
else:
|
||||
lr_1x.add(n)
|
||||
|
||||
lr_decay = sorted(list(lr_decay))
|
||||
lr_1x = sorted(list(lr_1x))
|
||||
lr_2x = sorted(list(lr_2x))
|
||||
lr_3x = sorted(list(lr_3x))
|
||||
# print('decay', lr_decay)
|
||||
# print('1x', lr_1x)
|
||||
# print('2x', lr_2x)
|
||||
# print('3x', lr_3x)
|
||||
param_dict = {n: p for n, p in self.named_parameters()}
|
||||
|
||||
if args.layerwise_lr > 0:
|
||||
if args.my_pile_stage == 2:
|
||||
optim_groups = [
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_1x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 1.0,
|
||||
},
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_2x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 5.0,
|
||||
}, # test: 2e-3 / args.lr_init},
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_3x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 5.0,
|
||||
}, # test: 3e-3 / args.lr_init},
|
||||
]
|
||||
else:
|
||||
optim_groups = [
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_1x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 1.0,
|
||||
},
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_2x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 2.0,
|
||||
},
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_3x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 3.0,
|
||||
},
|
||||
]
|
||||
else:
|
||||
optim_groups = [
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_1x],
|
||||
"weight_decay": 0.0,
|
||||
"my_lr_scale": 1.0,
|
||||
}
|
||||
]
|
||||
|
||||
if args.weight_decay > 0:
|
||||
optim_groups += [
|
||||
{
|
||||
"params": [param_dict[n] for n in lr_decay],
|
||||
"weight_decay": args.weight_decay,
|
||||
"my_lr_scale": 1.0,
|
||||
}
|
||||
]
|
||||
if self.deepspeed_offload:
|
||||
return DeepSpeedCPUAdam(
|
||||
optim_groups,
|
||||
lr=self.args.lr_init,
|
||||
betas=self.args.betas,
|
||||
eps=self.args.adam_eps,
|
||||
bias_correction=True,
|
||||
adamw_mode=True,
|
||||
amsgrad=False,
|
||||
)
|
||||
return FusedAdam(
|
||||
optim_groups,
|
||||
lr=self.args.lr_init,
|
||||
betas=self.args.betas,
|
||||
eps=self.args.adam_eps,
|
||||
bias_correction=True,
|
||||
adam_w_mode=True,
|
||||
amsgrad=False,
|
||||
)
|
||||
else:
|
||||
if self.deepspeed_offload:
|
||||
return DeepSpeedCPUAdam(
|
||||
optim_groups,
|
||||
lr=self.args.lr_init,
|
||||
betas=self.args.betas,
|
||||
eps=self.args.adam_eps,
|
||||
bias_correction=True,
|
||||
adamw_mode=False,
|
||||
weight_decay=0,
|
||||
amsgrad=False,
|
||||
)
|
||||
return FusedAdam(
|
||||
optim_groups,
|
||||
lr=self.args.lr_init,
|
||||
betas=self.args.betas,
|
||||
eps=self.args.adam_eps,
|
||||
bias_correction=True,
|
||||
adam_w_mode=False,
|
||||
weight_decay=0,
|
||||
amsgrad=False,
|
||||
)
|
||||
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
|
||||
|
||||
@property
|
||||
def deepspeed_offload(self) -> bool:
|
||||
strategy = self.trainer.strategy
|
||||
if isinstance(strategy, DeepSpeedStrategy):
|
||||
cfg = strategy.config["zero_optimization"]
|
||||
return cfg.get("offload_optimizer") or cfg.get("offload_param")
|
||||
return False
|
||||
|
||||
def forward(self, idx):
|
||||
args = self.args
|
||||
B, T = idx.size()
|
||||
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
|
||||
|
||||
x = self.emb(idx)
|
||||
x_emb = x
|
||||
|
||||
if args.dropout > 0:
|
||||
x = self.drop0(x)
|
||||
if args.tiny_att_dim > 0:
|
||||
for block in self.blocks:
|
||||
if args.grad_cp == 1:
|
||||
if args.lora:
|
||||
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
|
||||
else:
|
||||
x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
|
||||
else:
|
||||
x = block(x, x_emb)
|
||||
else:
|
||||
for block in self.blocks:
|
||||
if args.grad_cp == 1:
|
||||
if args.lora:
|
||||
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
|
||||
else:
|
||||
x = deepspeed.checkpointing.checkpoint(block, x)
|
||||
else:
|
||||
x = block(x)
|
||||
|
||||
x = self.ln_out(x)
|
||||
|
||||
if args.head_qk > 0:
|
||||
q = self.head_q(x)[:, :T, :]
|
||||
k = self.head_k(x)[:, :T, :]
|
||||
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
|
||||
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
|
||||
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
c = c @ F.one_hot(idx, num_classes=args.vocab_size)
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||
c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
|
||||
|
||||
x = self.head(x) + c
|
||||
else:
|
||||
x = self.head(x)
|
||||
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
args = self.args
|
||||
if args.my_qa_mask != 1:
|
||||
idx, targets = batch
|
||||
logits = self(idx)
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
# if '0' in os.environ["RWKV_MY_TESTING"]:
|
||||
# print('logits', logits)
|
||||
# torch.set_printoptions(threshold=10000)
|
||||
# print('idx', idx)
|
||||
# exit(0)
|
||||
else:
|
||||
idx, targets, mask = batch
|
||||
mask = mask.view(-1)
|
||||
sum_mask = torch.sum(mask).item()
|
||||
# if sum_mask == 0:
|
||||
# return torch.tensor([0.0], requires_grad=True)
|
||||
|
||||
logits = self(idx)
|
||||
if sum_mask == mask.shape[0]:
|
||||
loss = F.cross_entropy(
|
||||
logits.view(-1, logits.size(-1)), targets.view(-1)
|
||||
)
|
||||
# print('rank', self.global_rank, 'loss', loss.item())
|
||||
else:
|
||||
loss = F.cross_entropy(
|
||||
logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
|
||||
)
|
||||
# loss_raw = loss
|
||||
loss = torch.sum(loss * mask) / sum_mask
|
||||
|
||||
# torch.set_printoptions(threshold=10000)
|
||||
# if True: #self.global_rank == 1:
|
||||
# tmp = ''
|
||||
# sss = 0
|
||||
# ccc = 0
|
||||
# for i in range(mask.shape[0]):
|
||||
# if mask[i] > 0:
|
||||
# tmp += str(idx.view(-1)[i].item()) + ','
|
||||
# sss += loss_raw.view(-1)[i].float().item()
|
||||
# ccc += 1
|
||||
# print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
|
||||
return L2Wrap.apply(loss, logits)
|
||||
|
||||
def training_step_end(self, batch_parts):
|
||||
if pl.__version__[0] != "2":
|
||||
all = self.all_gather(batch_parts)
|
||||
if self.trainer.is_global_zero:
|
||||
self.trainer.my_loss_all = all
|
||||
|
||||
def generate_init_weight(self):
|
||||
print(
|
||||
f"""
|
||||
############################################################################
|
||||
#
|
||||
# Init model weight (slow for large models)...
|
||||
#
|
||||
############################################################################
|
||||
"""
|
||||
)
|
||||
m = {}
|
||||
for n in self.state_dict():
|
||||
p = self.state_dict()[n]
|
||||
shape = p.shape
|
||||
|
||||
gain = 1.0
|
||||
scale = 1.0
|
||||
if (
|
||||
"ln_" in n
|
||||
or ".ln" in n
|
||||
or "time_" in n
|
||||
or "_mask" in n
|
||||
or "pos_emb" in n
|
||||
or ".mask." in n
|
||||
):
|
||||
if "ln_x.weight" in n:
|
||||
layer_scale = (1 + int(n.split(".")[1])) / self.args.n_layer
|
||||
m[n] = (p * 0.0) + (layer_scale**0.7)
|
||||
else:
|
||||
m[n] = p
|
||||
else:
|
||||
if n == "emb.weight":
|
||||
scale = -1 * self.args.lr_init
|
||||
else:
|
||||
if shape[0] > shape[1]:
|
||||
gain = math.sqrt(shape[0] / shape[1])
|
||||
|
||||
zero = [
|
||||
".att.output.",
|
||||
".ffn.value.",
|
||||
".ffn.receptance.",
|
||||
".ffnPre.value.",
|
||||
".ffnPre.receptance.",
|
||||
"head_q.",
|
||||
".oo.",
|
||||
".rr.",
|
||||
]
|
||||
|
||||
for kk in zero:
|
||||
if kk in n:
|
||||
scale = 0
|
||||
if n == "head.weight":
|
||||
scale = 0.5
|
||||
if "head_k." in n:
|
||||
scale = 0.1
|
||||
if "head_q." in n:
|
||||
scale = 0
|
||||
|
||||
print(
|
||||
f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}"
|
||||
)
|
||||
|
||||
if self.args.accelerator.upper() == "GPU":
|
||||
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
|
||||
else:
|
||||
m[n] = torch.empty((shape[0], shape[1]))
|
||||
|
||||
if scale == 0:
|
||||
nn.init.zeros_(m[n])
|
||||
elif scale < 0:
|
||||
nn.init.uniform_(m[n], a=scale, b=-scale)
|
||||
else:
|
||||
nn.init.orthogonal_(m[n], gain=gain * scale)
|
||||
|
||||
m[n] = m[n].cpu()
|
||||
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||
m[n] = m[n].half()
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
m[n] = m[n].bfloat16()
|
||||
|
||||
# if n == "emb.weight":
|
||||
# print(m[n])
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return m
|
||||
310
finetune/lora/v5/src/trainer.py
vendored
Normal file
310
finetune/lora/v5/src/trainer.py
vendored
Normal file
@@ -0,0 +1,310 @@
|
||||
import os, math, time, datetime, subprocess
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
from .model import LORA_CONFIG
|
||||
|
||||
|
||||
def my_save(args, trainer, dd, ff):
|
||||
if "14b-run1" in ff:
|
||||
fn = ff.split("/")[-1]
|
||||
fff = "/dev/shm/" + fn
|
||||
torch.save(dd, fff)
|
||||
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
|
||||
elif ("world/14b" in ff) or ("world/7b" in ff):
|
||||
aa = ff.split("/")[1]
|
||||
fn = ff.split("/")[-1]
|
||||
fff = f"/dev/shm/{aa}-{fn}"
|
||||
torch.save(dd, fff)
|
||||
subprocess.Popen(
|
||||
f" aws s3 mv {fff} s3://rwkv-world/{aa}-{fn} --quiet", shell=True
|
||||
)
|
||||
else:
|
||||
if "deepspeed_stage_3" in args.strategy:
|
||||
trainer.save_checkpoint(ff, weights_only=True)
|
||||
else:
|
||||
torch.save(dd, ff)
|
||||
|
||||
|
||||
class train_callback(pl.Callback):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
||||
args = self.args
|
||||
# if args.cuda_cleanup > 0:
|
||||
# torch.cuda.empty_cache()
|
||||
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
||||
|
||||
# LR schedule
|
||||
w_step = args.warmup_steps
|
||||
if args.lr_final == args.lr_init or args.epoch_count == 0:
|
||||
lr = args.lr_init
|
||||
else:
|
||||
decay_step = real_step - args.my_pile_edecay * args.epoch_steps
|
||||
decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
|
||||
progress = (decay_step - w_step + 1) / (decay_total - w_step)
|
||||
progress = min(1, max(0, progress))
|
||||
|
||||
if args.lr_final == 0 or args.lr_init == 0: # linear decay
|
||||
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
|
||||
else: # exp decay
|
||||
lr = args.lr_init * math.exp(
|
||||
math.log(args.lr_final / args.lr_init) * pow(progress, 1)
|
||||
)
|
||||
# if trainer.is_global_zero:
|
||||
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
|
||||
|
||||
if args.my_exit_tokens != 0: # cosine decay
|
||||
real_tokens = real_step * args.ctx_len * args.real_bsz
|
||||
warmup_tokens = w_step * args.ctx_len * args.real_bsz
|
||||
progress = (real_tokens - warmup_tokens) / (
|
||||
abs(args.my_exit_tokens) - warmup_tokens
|
||||
)
|
||||
progress = max(0, min(1, progress))
|
||||
lr_final_factor = args.lr_final / args.lr_init
|
||||
lr_mult = (0.5 + lr_final_factor / 2) + (
|
||||
0.5 - lr_final_factor / 2
|
||||
) * math.cos(math.pi * progress)
|
||||
if args.my_exit_tokens > 0:
|
||||
lr = args.lr_init * lr_mult
|
||||
else:
|
||||
lr = (lr + args.lr_init * lr_mult) / 2
|
||||
if progress >= 1:
|
||||
if (trainer.is_global_zero) or ("deepspeed_stage_3" in args.strategy):
|
||||
my_save(
|
||||
args,
|
||||
trainer,
|
||||
pl_module.state_dict(),
|
||||
f"{args.proj_dir}/rwkv-final.pth",
|
||||
)
|
||||
exit(0)
|
||||
if trainer.global_step < w_step:
|
||||
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
|
||||
|
||||
if args.weight_decay_final > 0:
|
||||
wd_now = args.weight_decay * math.exp(
|
||||
math.log(args.weight_decay_final / args.weight_decay) * progress
|
||||
)
|
||||
else:
|
||||
wd_now = args.weight_decay
|
||||
|
||||
for param_group in trainer.optimizers[0].param_groups:
|
||||
if param_group["weight_decay"] > 0:
|
||||
param_group["weight_decay"] = wd_now
|
||||
if args.layerwise_lr > 0:
|
||||
param_group["lr"] = lr * param_group["my_lr_scale"]
|
||||
# print(param_group["lr"], param_group["my_lr_scale"])
|
||||
else:
|
||||
param_group["lr"] = lr
|
||||
|
||||
trainer.my_lr = lr
|
||||
trainer.my_wd = wd_now
|
||||
# rank_zero_info(f"{real_step} {lr}")
|
||||
|
||||
if trainer.global_step == 0:
|
||||
if trainer.is_global_zero: # logging
|
||||
trainer.my_loss_sum = 0
|
||||
trainer.my_loss_count = 0
|
||||
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
|
||||
trainer.my_log.write(
|
||||
f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n"
|
||||
)
|
||||
try:
|
||||
print(f"\n{trainer.strategy.config}\n")
|
||||
trainer.my_log.write(f"{trainer.strategy.config}\n")
|
||||
except:
|
||||
pass
|
||||
trainer.my_log.flush()
|
||||
if len(args.wandb) > 0:
|
||||
print("Login to wandb...")
|
||||
import wandb
|
||||
|
||||
wandb.init(
|
||||
project=args.wandb,
|
||||
name=args.run_name + " " + args.my_timestamp,
|
||||
config=args,
|
||||
save_code=False,
|
||||
)
|
||||
trainer.my_wandb = wandb
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
args = self.args
|
||||
token_per_step = args.ctx_len * args.real_bsz
|
||||
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
||||
if trainer.is_global_zero: # logging
|
||||
t_now = time.time_ns()
|
||||
kt_s = 0
|
||||
try:
|
||||
t_cost = (t_now - trainer.my_time_ns) / 1e9
|
||||
kt_s = token_per_step / t_cost / 1000
|
||||
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
|
||||
self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
|
||||
except:
|
||||
pass
|
||||
trainer.my_time_ns = t_now
|
||||
if pl.__version__[0] == "2":
|
||||
trainer.my_loss = outputs["loss"]
|
||||
else:
|
||||
trainer.my_loss = trainer.my_loss_all.float().mean().item()
|
||||
trainer.my_loss_sum += trainer.my_loss
|
||||
trainer.my_loss_count += 1
|
||||
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
|
||||
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
|
||||
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
|
||||
# self.log("s", real_step, prog_bar=True, on_step=True)
|
||||
|
||||
if len(args.wandb) > 0:
|
||||
lll = {
|
||||
"loss": trainer.my_loss,
|
||||
"lr": trainer.my_lr,
|
||||
"wd": trainer.my_wd,
|
||||
"Gtokens": real_step * token_per_step / 1e9,
|
||||
}
|
||||
if kt_s > 0:
|
||||
lll["kt/s"] = kt_s
|
||||
trainer.my_wandb.log(lll, step=int(real_step))
|
||||
if (trainer.is_global_zero) or (
|
||||
"deepspeed_stage_3" in args.strategy
|
||||
): # save pth
|
||||
if args.magic_prime > 0:
|
||||
expand_factor = 2 if args.my_qa_mask > 0 else 1
|
||||
if int(real_step) == int(
|
||||
args.magic_prime * expand_factor // args.real_bsz
|
||||
) - 1 + int(args.my_random_steps):
|
||||
to_save_dict = pl_module.state_dict()
|
||||
my_save(
|
||||
args,
|
||||
trainer,
|
||||
to_save_dict,
|
||||
f"{args.proj_dir}/rwkv-final.pth",
|
||||
)
|
||||
# if args.batch_save==batch_idx :
|
||||
# to_save_dict = pl_module.state_dict()
|
||||
# for name, state in to_save_dict.items():
|
||||
# if 'img' in name:
|
||||
# to_save_dict[name] = state
|
||||
# try:
|
||||
# my_save(
|
||||
# args, trainer,
|
||||
# to_save_dict,
|
||||
# f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}-{batch_idx}.pth",
|
||||
# )
|
||||
# except Exception as e:
|
||||
# print('Error\n\n', e, '\n\n')
|
||||
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
args = self.args
|
||||
if pl.__version__[0] == "2":
|
||||
dataset = trainer.train_dataloader.dataset
|
||||
else:
|
||||
dataset = trainer.train_dataloader.dataset.datasets
|
||||
assert "MyDataset" in str(dataset)
|
||||
dataset.global_rank = trainer.global_rank
|
||||
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
|
||||
dataset.world_size = trainer.world_size
|
||||
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
args = self.args
|
||||
to_save_dict = {}
|
||||
if (trainer.is_global_zero) or (
|
||||
"deepspeed_stage_3" in args.strategy
|
||||
): # save pth
|
||||
if (
|
||||
args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0
|
||||
) or (trainer.current_epoch == args.epoch_count - 1):
|
||||
if args.data_type == "wds_img":
|
||||
raw_dict = pl_module.state_dict()
|
||||
for k in raw_dict:
|
||||
if k.startswith("encoder.") or k.startswith("decoder."):
|
||||
to_save_dict[k] = raw_dict[k]
|
||||
else:
|
||||
to_save_dict = pl_module.state_dict()
|
||||
|
||||
if args.data_type == "img" and not args.lora:
|
||||
for name, state in to_save_dict.items():
|
||||
if "img" in name:
|
||||
to_save_dict[name] = state
|
||||
|
||||
if args.lora:
|
||||
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
||||
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
||||
lora_dict = {}
|
||||
for name, state in to_save_dict.items():
|
||||
if "img" in name:
|
||||
lora_dict[name] = state
|
||||
if (
|
||||
".lora_" in name
|
||||
or (enable_time_finetune and ".time_" in name)
|
||||
or (enable_ln_finetune and ".ln" in name)
|
||||
):
|
||||
lora_dict[name] = state
|
||||
to_save_dict = lora_dict
|
||||
|
||||
try:
|
||||
my_save(
|
||||
args,
|
||||
trainer,
|
||||
to_save_dict,
|
||||
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
|
||||
)
|
||||
except Exception as e:
|
||||
print("Error\n\n", e, "\n\n")
|
||||
|
||||
if trainer.is_global_zero: # logging
|
||||
trainer.my_log.write(
|
||||
f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n"
|
||||
)
|
||||
trainer.my_log.flush()
|
||||
|
||||
trainer.my_loss_sum = 0
|
||||
trainer.my_loss_count = 0
|
||||
if (args.epoch_begin + trainer.current_epoch) >= args.my_exit:
|
||||
exit(0)
|
||||
|
||||
|
||||
@rank_zero_only
|
||||
def generate_init_weight(model, init_weight_name):
|
||||
mm = model.generate_init_weight()
|
||||
|
||||
if model.args.my_pile_stage == 1:
|
||||
if len(model.args.load_model) > 0:
|
||||
print(f"Combine weights from {model.args.load_model}...")
|
||||
load_dict = torch.load(model.args.load_model, map_location="cpu")
|
||||
for k in load_dict:
|
||||
try:
|
||||
assert k in mm
|
||||
except:
|
||||
print("missing", k)
|
||||
exit(0)
|
||||
src = load_dict[k]
|
||||
try:
|
||||
mm[k] = src.reshape(mm[k].shape)
|
||||
except:
|
||||
tmp = mm[k].squeeze().clone()
|
||||
print(k, src.shape, "-->", mm[k].shape)
|
||||
ss = src.shape[0]
|
||||
dd = tmp.shape[0]
|
||||
for i in range(dd):
|
||||
pos = i / dd * ss
|
||||
if pos >= ss - 1:
|
||||
tmp[i] = src[ss - 1]
|
||||
else:
|
||||
p0 = int(math.floor(pos))
|
||||
ii = pos - p0
|
||||
tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii)
|
||||
mm[k] = tmp.reshape(mm[k].shape)
|
||||
sss = src.squeeze().float().cpu().numpy()
|
||||
print(sss[:10], "...", sss[-10:])
|
||||
mmm = mm[k].squeeze().float().cpu().numpy()
|
||||
print(mmm[:10], "...", mmm[-10:])
|
||||
|
||||
print(f"Save to {init_weight_name}...")
|
||||
torch.save(mm, init_weight_name)
|
||||
|
||||
if model.args.my_pile_stage == 1:
|
||||
print("Done. Now go for stage 2.")
|
||||
exit(0)
|
||||
139
finetune/lora/v5/src/utils.py
vendored
Normal file
139
finetune/lora/v5/src/utils.py
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
import json, time, random, os
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
time_slot = {}
|
||||
time_ref = time.time_ns()
|
||||
|
||||
|
||||
def record_time(name):
|
||||
if name not in time_slot:
|
||||
time_slot[name] = 1e20
|
||||
tt = (time.time_ns() - time_ref) / 1e9
|
||||
if tt < time_slot[name]:
|
||||
time_slot[name] = tt
|
||||
|
||||
|
||||
class TOKENIZER:
|
||||
def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"):
|
||||
if "list" in str(type(WORD_NAME)):
|
||||
self.charMode = False
|
||||
if WORD_NAME[0] == WORD_NAME[1]:
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
|
||||
else:
|
||||
from transformers import GPT2TokenizerFast
|
||||
|
||||
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
|
||||
self.vocab_size = len(self.tokenizer)
|
||||
else:
|
||||
self.charMode = True
|
||||
with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file:
|
||||
self.word_table = json.load(result_file)
|
||||
|
||||
self.vocab_size = len(self.word_table)
|
||||
|
||||
self.stoi = {v: int(k) for k, v in self.word_table.items()}
|
||||
self.itos = {int(k): v for k, v in self.word_table.items()}
|
||||
|
||||
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
|
||||
|
||||
def refine_context(self, context):
|
||||
context = context.strip().split("\n")
|
||||
for c in range(len(context)):
|
||||
context[c] = context[c].strip().strip("\u3000").strip("\r")
|
||||
context = list(filter(lambda c: c != "", context))
|
||||
context = "\n" + ("\n".join(context)).strip()
|
||||
if context == "":
|
||||
context = "\n"
|
||||
return context
|
||||
|
||||
def sample_logits(
|
||||
self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None
|
||||
):
|
||||
# out[self.UNKNOWN_CHAR] = -float('Inf')
|
||||
lastChar = int(x[-1])
|
||||
|
||||
probs = F.softmax(out, dim=-1)
|
||||
|
||||
if self.charMode:
|
||||
if self.itos[lastChar] == "\n":
|
||||
top_p = top_p_newline
|
||||
else:
|
||||
top_p = top_p_usual
|
||||
else:
|
||||
top_p = top_p_usual
|
||||
|
||||
if os.environ["RWKV_RUN_DEVICE"] == "cpu":
|
||||
probs = probs.numpy()
|
||||
sorted_probs = np.sort(probs)[::-1]
|
||||
cumulative_probs = np.cumsum(sorted_probs)
|
||||
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||
probs[probs < cutoff] = 0
|
||||
if temperature != 1.0:
|
||||
probs = probs.pow(1.0 / temperature)
|
||||
probs = probs / np.sum(probs)
|
||||
out = np.random.choice(a=len(probs), p=probs)
|
||||
return out
|
||||
else:
|
||||
sorted_probs = torch.sort(probs, descending=True)[0]
|
||||
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
|
||||
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||
probs[probs < cutoff] = 0
|
||||
if temperature != 1.0:
|
||||
probs = probs.pow(1.0 / temperature)
|
||||
out = torch.multinomial(probs, num_samples=1)[0]
|
||||
return out
|
||||
|
||||
|
||||
def MaybeIsPrime(number):
|
||||
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def FermatPrimalityTest(number):
|
||||
if number > 1:
|
||||
for time in range(3):
|
||||
randomNumber = random.randint(2, number) - 1
|
||||
if pow(randomNumber, number - 1, number) != 1:
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def MillerRabinPrimalityTest(number):
|
||||
if number == 2:
|
||||
return True
|
||||
elif number == 1 or number % 2 == 0:
|
||||
return False
|
||||
oddPartOfNumber = number - 1
|
||||
timesTwoDividNumber = 0
|
||||
while oddPartOfNumber % 2 == 0:
|
||||
oddPartOfNumber = oddPartOfNumber // 2
|
||||
timesTwoDividNumber = timesTwoDividNumber + 1
|
||||
|
||||
for time in range(3):
|
||||
while True:
|
||||
randomNumber = random.randint(2, number) - 1
|
||||
if randomNumber != 0 and randomNumber != 1:
|
||||
break
|
||||
|
||||
randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
|
||||
|
||||
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
|
||||
iterationNumber = 1
|
||||
|
||||
while (iterationNumber <= timesTwoDividNumber - 1) and (
|
||||
randomNumberWithPower != number - 1
|
||||
):
|
||||
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
|
||||
iterationNumber = iterationNumber + 1
|
||||
if randomNumberWithPower != (number - 1):
|
||||
return False
|
||||
|
||||
return True
|
||||
436
finetune/lora/v5/train.py
vendored
Normal file
436
finetune/lora/v5/train.py
vendored
Normal file
@@ -0,0 +1,436 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from argparse import ArgumentParser
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
import pytorch_lightning as pl
|
||||
|
||||
rank_zero_info("########## work in progress ##########")
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
|
||||
parser.add_argument(
|
||||
"--wandb", default="", type=str
|
||||
) # wandb project name. if "" then don't use wandb
|
||||
parser.add_argument("--proj_dir", default="out", type=str)
|
||||
parser.add_argument("--random_seed", default="-1", type=int)
|
||||
|
||||
parser.add_argument("--data_file", default="", type=str)
|
||||
parser.add_argument("--data_type", default="utf-8", type=str)
|
||||
parser.add_argument(
|
||||
"--vocab_size", default=0, type=int
|
||||
) # vocab_size = 0 means auto (for char-level LM and .txt data)
|
||||
|
||||
parser.add_argument("--ctx_len", default=1024, type=int)
|
||||
parser.add_argument(
|
||||
"--epoch_steps", default=1000, type=int
|
||||
) # a mini "epoch" has [epoch_steps] steps
|
||||
parser.add_argument(
|
||||
"--epoch_count", default=500, type=int
|
||||
) # train for this many "epochs". will continue afterwards with lr = lr_final
|
||||
parser.add_argument(
|
||||
"--epoch_begin", default=0, type=int
|
||||
) # if you load a model trained for x "epochs", set epoch_begin = x
|
||||
parser.add_argument(
|
||||
"--epoch_save", default=5, type=int
|
||||
) # save the model every [epoch_save] "epochs"
|
||||
|
||||
parser.add_argument(
|
||||
"--micro_bsz", default=12, type=int
|
||||
) # micro batch size (batch size per GPU)
|
||||
parser.add_argument("--n_layer", default=6, type=int)
|
||||
parser.add_argument("--n_embd", default=512, type=int)
|
||||
parser.add_argument("--dim_att", default=0, type=int)
|
||||
parser.add_argument("--dim_ffn", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--pre_ffn", default=0, type=int
|
||||
) # replace first att layer by ffn (sometimes better)
|
||||
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
|
||||
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
|
||||
parser.add_argument(
|
||||
"--tiny_att_layer", default=-999, type=int
|
||||
) # tiny attention @ which layer
|
||||
|
||||
parser.add_argument(
|
||||
"--lr_init", default=6e-4, type=float
|
||||
) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
|
||||
parser.add_argument("--lr_final", default=1e-5, type=float)
|
||||
parser.add_argument(
|
||||
"--warmup_steps", default=-1, type=int
|
||||
) # try 50 if you load a model
|
||||
parser.add_argument("--beta1", default=0.9, type=float)
|
||||
parser.add_argument(
|
||||
"--beta2", default=0.99, type=float
|
||||
) # use 0.999 when your model is close to convergence
|
||||
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
||||
parser.add_argument(
|
||||
"--grad_cp", default=0, type=int
|
||||
) # gradient checkpt: saves VRAM, but slower
|
||||
parser.add_argument(
|
||||
"--dropout", default=0, type=float
|
||||
) # try 0.01 / 0.02 / 0.05 / 0.1
|
||||
parser.add_argument(
|
||||
"--weight_decay", default=0, type=float
|
||||
) # try 0.1 / 0.01 / 0.001
|
||||
parser.add_argument("--weight_decay_final", default=-1, type=float)
|
||||
|
||||
parser.add_argument(
|
||||
"--my_pile_version", default=1, type=int
|
||||
) # my special pile version
|
||||
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
|
||||
parser.add_argument(
|
||||
"--my_pile_shift", default=-1, type=int
|
||||
) # my special pile mode - text shift
|
||||
parser.add_argument("--my_pile_edecay", default=0, type=int)
|
||||
parser.add_argument(
|
||||
"--layerwise_lr", default=1, type=int
|
||||
) # layerwise lr for faster convergence (but slower it/s)
|
||||
parser.add_argument(
|
||||
"--ds_bucket_mb", default=200, type=int
|
||||
) # deepspeed bucket size in MB. 200 seems enough
|
||||
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
|
||||
|
||||
parser.add_argument("--my_sample_len", default=0, type=int)
|
||||
parser.add_argument("--my_ffn_shift", default=1, type=int)
|
||||
parser.add_argument("--my_att_shift", default=1, type=int)
|
||||
parser.add_argument(
|
||||
"--head_size_a", default=64, type=int
|
||||
) # can try larger values for larger models
|
||||
parser.add_argument("--head_size_divisor", default=8, type=int)
|
||||
parser.add_argument("--my_pos_emb", default=0, type=int)
|
||||
parser.add_argument("--load_partial", default=0, type=int)
|
||||
parser.add_argument("--magic_prime", default=0, type=int)
|
||||
parser.add_argument("--my_qa_mask", default=0, type=int)
|
||||
parser.add_argument("--my_random_steps", default=0, type=int)
|
||||
parser.add_argument("--my_testing", default="", type=str)
|
||||
parser.add_argument("--my_exit", default=99999999, type=int)
|
||||
parser.add_argument("--my_exit_tokens", default=0, type=int)
|
||||
|
||||
# LORA
|
||||
parser.add_argument("--emb", action="store_true")
|
||||
parser.add_argument("--lora", action="store_true")
|
||||
parser.add_argument("--lora_load", default="", type=str)
|
||||
parser.add_argument("--lora_r", default=8, type=int)
|
||||
parser.add_argument("--lora_alpha", default=32, type=float)
|
||||
parser.add_argument("--lora_dropout", default=0.01, type=float)
|
||||
parser.add_argument("--lora_parts", default="att,ln,time", type=str)
|
||||
|
||||
if pl.__version__[0] == "2":
|
||||
parser.add_argument("--accelerator", default="gpu", type=str)
|
||||
parser.add_argument("--strategy", default="auto", type=str)
|
||||
parser.add_argument("--devices", default=1, type=int)
|
||||
parser.add_argument("--num_nodes", default=1, type=int)
|
||||
parser.add_argument("--precision", default="fp16", type=str)
|
||||
parser.add_argument("--accumulate_grad_batches", default=1, type=int)
|
||||
else:
|
||||
parser = Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
########################################################################################################
|
||||
|
||||
import os, warnings, math, datetime, sys, time
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
if "deepspeed" in args.strategy:
|
||||
import deepspeed
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
if args.random_seed >= 0:
|
||||
print(
|
||||
f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n"
|
||||
* 3
|
||||
)
|
||||
seed_everything(args.random_seed)
|
||||
|
||||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||
warnings.filterwarnings(
|
||||
"ignore", ".*Consider increasing the value of the `num_workers` argument*"
|
||||
)
|
||||
warnings.filterwarnings(
|
||||
"ignore", ".*The progress bar already tracks a metric with the*"
|
||||
)
|
||||
# os.environ["WDS_SHOW_SEED"] = "1"
|
||||
|
||||
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
args.enable_checkpointing = False
|
||||
args.replace_sampler_ddp = False
|
||||
args.logger = False
|
||||
args.gradient_clip_val = 1.0
|
||||
args.num_sanity_val_steps = 0
|
||||
args.check_val_every_n_epoch = int(1e20)
|
||||
args.log_every_n_steps = int(1e20)
|
||||
args.max_epochs = args.epoch_count # -1 continue forever
|
||||
args.betas = (args.beta1, args.beta2)
|
||||
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
|
||||
os.environ["RWKV_MY_TESTING"] = args.my_testing
|
||||
os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)
|
||||
if args.dim_att <= 0:
|
||||
args.dim_att = args.n_embd
|
||||
if args.dim_ffn <= 0:
|
||||
args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size
|
||||
|
||||
if args.data_type == "wds_img":
|
||||
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
|
||||
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
|
||||
else:
|
||||
args.run_name = (
|
||||
f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
|
||||
)
|
||||
if not os.path.exists(args.proj_dir):
|
||||
os.makedirs(args.proj_dir)
|
||||
|
||||
if args.my_pile_stage > 0:
|
||||
magic_prime_bak = args.magic_prime
|
||||
|
||||
if args.my_pile_shift < 0:
|
||||
args.my_pile_shift = 0
|
||||
|
||||
if magic_prime_bak > 0:
|
||||
args.magic_prime = magic_prime_bak
|
||||
if args.my_qa_mask == 2:
|
||||
args.epoch_count = 2 * args.magic_prime // 40320
|
||||
else:
|
||||
args.epoch_count = args.magic_prime // 40320
|
||||
|
||||
args.epoch_steps = 40320 // args.real_bsz
|
||||
assert args.epoch_steps * args.real_bsz == 40320
|
||||
# if args.my_pile_stage == 2:
|
||||
# assert args.lr_final == args.lr_init
|
||||
if args.my_pile_stage >= 2: # find latest saved model
|
||||
list_p = []
|
||||
for p in os.listdir(args.proj_dir):
|
||||
if p.startswith("rwkv") and p.endswith(".pth"):
|
||||
p = ((p.split("-"))[1].split("."))[0]
|
||||
if p != "final":
|
||||
if p == "init":
|
||||
p = -1
|
||||
else:
|
||||
p = int(p)
|
||||
list_p += [p]
|
||||
list_p.sort()
|
||||
max_p = list_p[-1]
|
||||
if len(list_p) > 1:
|
||||
args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted
|
||||
if max_p == -1:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
||||
else:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
||||
if args.warmup_steps < 0:
|
||||
if args.my_pile_stage == 2:
|
||||
args.warmup_steps = 10
|
||||
else:
|
||||
args.warmup_steps = 30
|
||||
args.epoch_begin = max_p + 1
|
||||
|
||||
samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||
tokens_per_epoch = samples_per_epoch * args.ctx_len
|
||||
try:
|
||||
deepspeed_version = deepspeed.__version__
|
||||
except:
|
||||
deepspeed_version = None
|
||||
pass
|
||||
rank_zero_info(
|
||||
f"""
|
||||
############################################################################
|
||||
#
|
||||
# RWKV-5 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
|
||||
#
|
||||
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
|
||||
#
|
||||
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1}, save every {args.epoch_save} epoch
|
||||
#
|
||||
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
|
||||
#
|
||||
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
|
||||
#
|
||||
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
|
||||
#
|
||||
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
|
||||
# Found deepspeed {deepspeed_version}, recommend 0.7.0 (faster than newer versions)
|
||||
# Found pytorch_lightning {pl.__version__}, recommend 1.9.5
|
||||
#
|
||||
############################################################################
|
||||
"""
|
||||
)
|
||||
rank_zero_info(str(vars(args)) + "\n")
|
||||
|
||||
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "uint16"]
|
||||
|
||||
if args.lr_final == 0 or args.lr_init == 0:
|
||||
rank_zero_info(
|
||||
"\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n"
|
||||
)
|
||||
|
||||
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
|
||||
os.environ["RWKV_FLOAT_MODE"] = args.precision
|
||||
if args.precision == "fp32":
|
||||
for i in range(10):
|
||||
rank_zero_info(
|
||||
"\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n"
|
||||
)
|
||||
if args.precision == "fp16":
|
||||
rank_zero_info(
|
||||
"\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n"
|
||||
)
|
||||
|
||||
os.environ["RWKV_JIT_ON"] = "0"
|
||||
if "deepspeed_stage_3" in args.strategy:
|
||||
os.environ["RWKV_JIT_ON"] = "0"
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
if args.precision == "fp32":
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
else:
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if "32" in args.precision:
|
||||
args.precision = 32
|
||||
elif args.precision == "fp16":
|
||||
args.precision = 16
|
||||
else:
|
||||
args.precision = "bf16"
|
||||
|
||||
########################################################################################################
|
||||
|
||||
from src.trainer import train_callback, generate_init_weight
|
||||
from src.dataset import MyDataset
|
||||
|
||||
train_data = MyDataset(args)
|
||||
args.vocab_size = train_data.vocab_size
|
||||
|
||||
from src.model import RWKV, LORA_CONFIG, LoraLinear
|
||||
|
||||
if args.lora:
|
||||
assert args.lora_r > 0, "LoRA should have its `r` > 0"
|
||||
LORA_CONFIG["r"] = args.lora_r
|
||||
LORA_CONFIG["alpha"] = args.lora_alpha
|
||||
LORA_CONFIG["dropout"] = args.lora_dropout
|
||||
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
|
||||
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
||||
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
||||
model = RWKV(args)
|
||||
# only train lora parameters
|
||||
if args.lora:
|
||||
model.requires_grad_(False)
|
||||
for name, module in model.named_modules():
|
||||
if any(n.startswith("lora_") for n, _ in module.named_parameters()):
|
||||
print(f" LoRA additionally training module {name}")
|
||||
for pname, param in module.named_parameters():
|
||||
param.requires_grad = "lora_" in pname
|
||||
elif enable_ln_finetune and ".ln" in name:
|
||||
print(f" LoRA additionally training module {name}")
|
||||
for param in module.parameters():
|
||||
param.requires_grad = True
|
||||
elif enable_time_finetune and any(
|
||||
n.startswith("time") for n, _ in module.named_parameters()
|
||||
):
|
||||
for pname, param in module.named_parameters():
|
||||
if pname.startswith("time"):
|
||||
print(f" LoRA additionally training parameter {pname}")
|
||||
param.requires_grad = True
|
||||
|
||||
if (
|
||||
len(args.load_model) == 0 or args.my_pile_stage == 1
|
||||
): # shall we build the initial weights?
|
||||
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
|
||||
generate_init_weight(model, init_weight_name) # save initial weights
|
||||
args.load_model = init_weight_name
|
||||
|
||||
rank_zero_info(f"########## Loading {args.load_model}... ##########")
|
||||
try:
|
||||
load_dict = torch.load(args.load_model, map_location="cpu")
|
||||
load_keys = list(load_dict.keys())
|
||||
for k in load_keys:
|
||||
if k.startswith("_forward_module."):
|
||||
load_dict[k.replace("_forward_module.", "")] = load_dict[k]
|
||||
del load_dict[k]
|
||||
except:
|
||||
rank_zero_info(f"Bad checkpoint {args.load_model}")
|
||||
if args.my_pile_stage >= 2: # try again using another checkpoint
|
||||
max_p = args.my_pile_prev_p
|
||||
if max_p == -1:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
||||
else:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
||||
args.epoch_begin = max_p + 1
|
||||
rank_zero_info(f"Trying {args.load_model}")
|
||||
load_dict = torch.load(args.load_model, map_location="cpu")
|
||||
|
||||
if args.load_partial == 1:
|
||||
load_keys = load_dict.keys()
|
||||
for k in model.state_dict():
|
||||
if k not in load_keys:
|
||||
load_dict[k] = model.state_dict()[k]
|
||||
# model.load_state_dict(load_dict)
|
||||
|
||||
model.load_state_dict(load_dict, strict=(not args.lora))
|
||||
if os.path.isfile(args.lora_load):
|
||||
model.load_state_dict(
|
||||
torch.load(args.lora_load, map_location="cpu"), strict=False
|
||||
)
|
||||
|
||||
if pl.__version__[0] == "2":
|
||||
trainer = Trainer(
|
||||
accelerator=args.accelerator,
|
||||
strategy=args.strategy,
|
||||
devices=args.devices,
|
||||
num_nodes=args.num_nodes,
|
||||
precision=args.precision,
|
||||
logger=args.logger,
|
||||
callbacks=[train_callback(args)],
|
||||
max_epochs=args.max_epochs,
|
||||
check_val_every_n_epoch=args.check_val_every_n_epoch,
|
||||
num_sanity_val_steps=args.num_sanity_val_steps,
|
||||
log_every_n_steps=args.log_every_n_steps,
|
||||
enable_checkpointing=args.enable_checkpointing,
|
||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||
gradient_clip_val=args.gradient_clip_val,
|
||||
)
|
||||
else:
|
||||
trainer = Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[train_callback(args)],
|
||||
)
|
||||
|
||||
if trainer.global_rank == 0:
|
||||
for n in model.state_dict():
|
||||
shape = model.state_dict()[n].shape
|
||||
shape = [i for i in shape if i != 1]
|
||||
if len(shape) > 1:
|
||||
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
|
||||
else:
|
||||
print(f"{str(shape[0]).ljust(5)} {n}")
|
||||
|
||||
if "deepspeed" in args.strategy:
|
||||
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
|
||||
args.ds_bucket_mb * 1000 * 1000
|
||||
)
|
||||
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
|
||||
args.ds_bucket_mb * 1000 * 1000
|
||||
)
|
||||
|
||||
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
||||
data_loader = DataLoader(
|
||||
train_data,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
batch_size=args.micro_bsz,
|
||||
num_workers=1,
|
||||
persistent_workers=False,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
trainer.fit(model, data_loader)
|
||||
32
frontend/package-lock.json
generated
32
frontend/package-lock.json
generated
@@ -13,11 +13,13 @@
|
||||
"@magenta/music": "^1.23.1",
|
||||
"@microsoft/fetch-event-source": "^2.0.1",
|
||||
"@primer/octicons-react": "^19.1.0",
|
||||
"abcjs": "^6.2.3",
|
||||
"chart.js": "^4.3.0",
|
||||
"classnames": "^2.3.2",
|
||||
"file-saver": "^2.0.5",
|
||||
"html-midi-player": "^1.5.0",
|
||||
"i18next": "^22.4.15",
|
||||
"lodash-es": "^4.17.21",
|
||||
"mobx": "^6.9.0",
|
||||
"mobx-react-lite": "^3.4.3",
|
||||
"pdfjs-dist": "^4.0.189",
|
||||
@@ -40,6 +42,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/file-saver": "^2.0.7",
|
||||
"@types/lodash-es": "^4.17.12",
|
||||
"@types/react": "^18.2.6",
|
||||
"@types/react-beautiful-dnd": "^13.1.4",
|
||||
"@types/react-dom": "^18.2.4",
|
||||
@@ -2533,6 +2536,21 @@
|
||||
"hoist-non-react-statics": "^3.3.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/lodash": {
|
||||
"version": "4.14.202",
|
||||
"resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.14.202.tgz",
|
||||
"integrity": "sha512-OvlIYQK9tNneDlS0VN54LLd5uiPCBOp7gS5Z0f1mjoJYBrtStzgmJBxONW3U6OZqdtNzZPmn9BS/7WI7BFFcFQ==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@types/lodash-es": {
|
||||
"version": "4.17.12",
|
||||
"resolved": "https://registry.npmjs.org/@types/lodash-es/-/lodash-es-4.17.12.tgz",
|
||||
"integrity": "sha512-0NgftHUcV4v34VhXm8QBSftKVXtbkBG3ViCjs6+eJ5a6y6Mi/jiFGPc1sC7QK+9BFhWrURE3EOggmWaSxL9OzQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@types/lodash": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/long": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz",
|
||||
@@ -2673,6 +2691,15 @@
|
||||
"integrity": "sha512-nne9/IiQ/hzIhY6pdDnbBtz7DjPTKrY00P/zvPSm5pOFkl6xuGrGnXn/VtTNNfNtAfZ9/1RtehkszU9qcTii0Q==",
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/abcjs": {
|
||||
"version": "6.2.3",
|
||||
"resolved": "https://registry.npmjs.org/abcjs/-/abcjs-6.2.3.tgz",
|
||||
"integrity": "sha512-epu8C1yRkxV7Ss9hS0Bu72rairl1p2sR3hviVowjtdDJvb5GRE0SrB4TtN4HBbaoYhvxGnSZQxGULfQlW3o3RQ==",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/paulrosen"
|
||||
}
|
||||
},
|
||||
"node_modules/acorn": {
|
||||
"version": "7.4.1",
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-7.4.1.tgz",
|
||||
@@ -4210,6 +4237,11 @@
|
||||
"integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/lodash-es": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz",
|
||||
"integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw=="
|
||||
},
|
||||
"node_modules/long": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz",
|
||||
|
||||
@@ -14,11 +14,13 @@
|
||||
"@magenta/music": "^1.23.1",
|
||||
"@microsoft/fetch-event-source": "^2.0.1",
|
||||
"@primer/octicons-react": "^19.1.0",
|
||||
"abcjs": "^6.2.3",
|
||||
"chart.js": "^4.3.0",
|
||||
"classnames": "^2.3.2",
|
||||
"file-saver": "^2.0.5",
|
||||
"html-midi-player": "^1.5.0",
|
||||
"i18next": "^22.4.15",
|
||||
"lodash-es": "^4.17.21",
|
||||
"mobx": "^6.9.0",
|
||||
"mobx-react-lite": "^3.4.3",
|
||||
"pdfjs-dist": "^4.0.189",
|
||||
@@ -41,6 +43,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/file-saver": "^2.0.7",
|
||||
"@types/lodash-es": "^4.17.12",
|
||||
"@types/react": "^18.2.6",
|
||||
"@types/react-beautiful-dnd": "^13.1.4",
|
||||
"@types/react-dom": "^18.2.4",
|
||||
|
||||
@@ -128,7 +128,7 @@
|
||||
"Chinese Kongfu": "中国武術",
|
||||
"Allow external access to the API (service must be restarted)": "APIへの外部アクセスを許可する (サービスを再起動する必要があります)",
|
||||
"Custom": "カスタム",
|
||||
"CUDA (Beta, Faster)": "CUDA (ベータ、高速)",
|
||||
"CUDA (Beta, Faster)": "CUDA (Beta, 高速)",
|
||||
"Reset All Configs": "すべての設定をリセット",
|
||||
"Cancel": "キャンセル",
|
||||
"Confirm": "確認",
|
||||
@@ -162,7 +162,7 @@
|
||||
"Memory is not enough, try to increase the virtual memory or use a smaller model.": "メモリが不足しています。仮想メモリを増やすか、もしくは小さなモデルを使ってみてください",
|
||||
"Bad PyTorch version, please reinstall PyTorch with cuda.": "不適切なPyTorchのバージョンです。cudaと共にPyTorchを再インストールしてください。",
|
||||
"The model file is corrupted, please download again.": "モデルファイルが破損しています。再度ダウンロードしてください。",
|
||||
"Found no NVIDIA driver, please install the latest driver.": "NVIDIAのドライバが見つかりません。最新版のドライバをインストールしてください。",
|
||||
"Found no NVIDIA driver, please install the latest driver. If you are not using an Nvidia GPU, please switch the 'Strategy' to WebGPU or CPU in the Configs page.": "NVIDIAのドライバが見つかりません。最新版のドライバをインストールしてください。NvidiaのGPUを使用していない場合は、設定ページで\"Strategy\"をWebGPUまたはCPUに切り替えてください。",
|
||||
"VRAM is not enough, please reduce stored layers or use a lower precision in Configs page.": "VRAMが足りません。設定ページで保存されているレイヤーを減らすか、精度を下げてください。",
|
||||
"Failed to enable custom CUDA kernel, ninja is required to load C++ extensions. You may be using the CPU version of PyTorch, please reinstall PyTorch with CUDA. Or if you are using a custom Python interpreter, you must compile the CUDA kernel by yourself or disable Custom CUDA kernel acceleration.": "カスタムCUDAカーネルの有効化に失敗しました。C++拡張を読み込むためにはNinjaが必要です。あなたは恐らくCPU版のPyTorchを使用しており、CUDA版のPyTorchを再インストールする必要があります。または、あなたがカスタムPythonインタプリタを使用している場合は、CUDAカーネルを自分でコンパイルするか、カスタムCUDAカーネルのアクセラレーションを無効にする必要があります。",
|
||||
"Presets": "プリセット",
|
||||
@@ -250,13 +250,13 @@
|
||||
"VRAM": "VRAM",
|
||||
"GPU Usage": "GPU使用率",
|
||||
"Use Custom Tokenizer": "カスタムトークナイザーを使用する",
|
||||
"Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json)": "トークナイザーパス (例: backend-python/rwkv_pip/20B_tokenizer.json)",
|
||||
"Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json or rwkv_vocab_v20230424.txt)": "トークナイザーパス (例: backend-python/rwkv_pip/20B_tokenizer.json または rwkv_vocab_v20230424.txt)",
|
||||
"User Name": "ユーザー名",
|
||||
"Assistant Name": "アシスタント名",
|
||||
"Insert default system prompt at the beginning": "最初にデフォルトのシステムプロンプトを挿入",
|
||||
"Format Content": "内容フォーマットの規格化",
|
||||
"Add An Attachment (Accepts pdf, txt)": "添付ファイルを追加 (pdf, txtを受け付けます)",
|
||||
"Uploading Attachment": "添付ファイルアップロード中",
|
||||
"Processing Attachment": "添付ファイルを処理中",
|
||||
"Remove Attachment": "添付ファイルを削除",
|
||||
"The content of file": "ファイル",
|
||||
"is as follows. When replying to me, consider the file content and respond accordingly:": "の内容は以下の通りです。私に返信する際は、ファイルの内容を考慮して適切に返信してください:",
|
||||
@@ -295,5 +295,34 @@
|
||||
"Sax": "サックス",
|
||||
"Flute": "フルート",
|
||||
"Lead": "リード",
|
||||
"Pad": "パッド"
|
||||
"Pad": "パッド",
|
||||
"MIDI Input": "MIDI入力",
|
||||
"Select the MIDI input device to be used.": "使用するMIDI入力デバイスを選択します。",
|
||||
"Start Time": "開始時間",
|
||||
"Content Duration": "内容の長さ",
|
||||
"Please select a MIDI device first": "まずMIDIデバイスを選択してください",
|
||||
"Piano is the main instrument": "ピアノはメインの楽器です",
|
||||
"Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Lossが大きすぎます、トレーニングデータを確認し、GPUドライバが最新であることを確認してください。",
|
||||
"This version of RWKV is not supported yet.": "このバージョンのRWKVはまだサポートされていません。",
|
||||
"Main": "メイン",
|
||||
"Finetuned": "微調整",
|
||||
"Global": "グローバル",
|
||||
"Local": "ローカル",
|
||||
"CN": "中国語",
|
||||
"JP": "日本語",
|
||||
"Music": "音楽",
|
||||
"Other": "その他",
|
||||
"Role Play": "ロールプレイ",
|
||||
"Recommended": "おすすめ",
|
||||
"Import MIDI": "MIDIをインポート",
|
||||
"Current Instrument": "現在の楽器",
|
||||
"Please convert model to GGML format first": "モデルをGGML形式に変換してください",
|
||||
"Convert To GGML Format": "GGML形式に変換",
|
||||
"CPU (rwkv.cpp, Faster)": "CPU (rwkv.cpp, 高速)",
|
||||
"Play With External Player": "外部プレーヤーで再生",
|
||||
"Core API URL": "コアAPI URL",
|
||||
"Override core API URL(/chat/completions and /completions). If you don't know what this is, leave it blank.": "コアAPI URLを上書きします(/chat/completions と /completions)。何であるかわからない場合は空白のままにしてください。",
|
||||
"Please change Strategy to CPU (rwkv.cpp) to use ggml format": "StrategyをCPU (rwkv.cpp)に変更して、ggml形式を使用してください",
|
||||
"Only Auto Play Generated Content": "生成されたコンテンツのみ自動再生",
|
||||
"Model has been converted and does not match current strategy. If you are using a new strategy, re-convert the model.": "モデルが変換され、現在の戦略と一致しません。新しい戦略を使用している場合は、モデルを再変換してください。"
|
||||
}
|
||||
@@ -162,7 +162,7 @@
|
||||
"Memory is not enough, try to increase the virtual memory or use a smaller model.": "内存不足,尝试增加虚拟内存,或使用一个更小规模的模型",
|
||||
"Bad PyTorch version, please reinstall PyTorch with cuda.": "错误的PyTorch版本,请重新安装CUDA版本的PyTorch",
|
||||
"The model file is corrupted, please download again.": "模型文件损坏,请重新下载",
|
||||
"Found no NVIDIA driver, please install the latest driver.": "没有找到NVIDIA驱动,请安装最新驱动",
|
||||
"Found no NVIDIA driver, please install the latest driver. If you are not using an Nvidia GPU, please switch the 'Strategy' to WebGPU or CPU in the Configs page.": "没有找到NVIDIA驱动,请安装最新驱动。如果你没有使用Nvidia显卡,请在配置页面将“Strategy”改为WebGPU或CPU",
|
||||
"VRAM is not enough, please reduce stored layers or use a lower precision in Configs page.": "显存不足,请在配置页面减少载入显存层数,或使用更低的精度",
|
||||
"Failed to enable custom CUDA kernel, ninja is required to load C++ extensions. You may be using the CPU version of PyTorch, please reinstall PyTorch with CUDA. Or if you are using a custom Python interpreter, you must compile the CUDA kernel by yourself or disable Custom CUDA kernel acceleration.": "自定义CUDA算子开启失败,需要安装Ninja来读取C++扩展。你可能正在使用CPU版本的PyTorch,请重新安装CUDA版本的PyTorch。如果你正在使用自定义Python解释器,你必须自己编译CUDA算子或禁用自定义CUDA算子加速",
|
||||
"Presets": "预设",
|
||||
@@ -250,13 +250,13 @@
|
||||
"VRAM": "显存",
|
||||
"GPU Usage": "GPU占用",
|
||||
"Use Custom Tokenizer": "使用自定义Tokenizer",
|
||||
"Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json)": "Tokenizer路径 (例如: backend-python/rwkv_pip/20B_tokenizer.json)",
|
||||
"Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json or rwkv_vocab_v20230424.txt)": "Tokenizer路径 (例如: backend-python/rwkv_pip/20B_tokenizer.json 或 rwkv_vocab_v20230424.txt)",
|
||||
"User Name": "用户名称",
|
||||
"Assistant Name": "AI名称",
|
||||
"Insert default system prompt at the beginning": "在开头自动插入默认系统提示",
|
||||
"Format Content": "规范格式",
|
||||
"Add An Attachment (Accepts pdf, txt)": "添加一个附件 (支持pdf, txt)",
|
||||
"Uploading Attachment": "正在上传附件",
|
||||
"Processing Attachment": "正在处理附件",
|
||||
"Remove Attachment": "移除附件",
|
||||
"The content of file": "文件",
|
||||
"is as follows. When replying to me, consider the file content and respond accordingly:": "内容如下。回复时考虑文件内容并做出相应回复:",
|
||||
@@ -295,5 +295,34 @@
|
||||
"Sax": "萨克斯",
|
||||
"Flute": "长笛",
|
||||
"Lead": "主音",
|
||||
"Pad": "和音"
|
||||
"Pad": "和音",
|
||||
"MIDI Input": "MIDI输入",
|
||||
"Select the MIDI input device to be used.": "选择要使用的MIDI输入设备",
|
||||
"Start Time": "开始时间",
|
||||
"Content Duration": "内容时长",
|
||||
"Please select a MIDI device first": "请先选择一个MIDI设备",
|
||||
"Piano is the main instrument": "钢琴为主",
|
||||
"Loss is too high, please check the training data, and ensure your gpu driver is up to date.": "Loss过高,请检查训练数据,并确保你的显卡驱动是最新的",
|
||||
"This version of RWKV is not supported yet.": "暂不支持此版本的RWKV",
|
||||
"Main": "主干",
|
||||
"Finetuned": "微调",
|
||||
"Global": "全球",
|
||||
"Local": "本地",
|
||||
"CN": "中文",
|
||||
"JP": "日文",
|
||||
"Music": "音乐",
|
||||
"Other": "其他",
|
||||
"Role Play": "角色扮演",
|
||||
"Recommended": "推荐",
|
||||
"Import MIDI": "导入MIDI",
|
||||
"Current Instrument": "当前乐器",
|
||||
"Please convert model to GGML format first": "请先将模型转换为GGML格式",
|
||||
"Convert To GGML Format": "转换为GGML格式",
|
||||
"CPU (rwkv.cpp, Faster)": "CPU (rwkv.cpp, 更快)",
|
||||
"Play With External Player": "使用外部播放器播放",
|
||||
"Core API URL": "核心 API URL",
|
||||
"Override core API URL(/chat/completions and /completions). If you don't know what this is, leave it blank.": "覆盖核心的 API URL (/chat/completions 和 /completions)。如果你不知道这是什么,请留空",
|
||||
"Please change Strategy to CPU (rwkv.cpp) to use ggml format": "请将Strategy改为CPU (rwkv.cpp)以使用ggml格式",
|
||||
"Only Auto Play Generated Content": "仅自动播放新生成的内容",
|
||||
"Model has been converted and does not match current strategy. If you are using a new strategy, re-convert the model.": "所选模型已被转换过,并且不匹配当前的Strategy。如果你正在使用新的Strategy,请重新转换模型"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import { FC } from 'react';
|
||||
import { observer } from 'mobx-react-lite';
|
||||
import { Dropdown, Option } from '@fluentui/react-components';
|
||||
import { Dropdown, Option, PresenceBadge } from '@fluentui/react-components';
|
||||
import commonStore from '../stores/commonStore';
|
||||
|
||||
export const ConfigSelector: FC<{ size?: 'small' | 'medium' | 'large' }> = observer(({ size }) => {
|
||||
@@ -12,7 +12,13 @@ export const ConfigSelector: FC<{ size?: 'small' | 'medium' | 'large' }> = obser
|
||||
commonStore.setCurrentConfigIndex(Number(data.optionValue));
|
||||
}}>
|
||||
{commonStore.modelConfigs.map((config, index) =>
|
||||
<Option key={index} value={index.toString()}>{config.name}</Option>
|
||||
<Option key={index} value={index.toString()} text={config.name}>
|
||||
<div className="flex justify-between grow">
|
||||
{config.name}
|
||||
{commonStore.modelSourceList.find(item => item.name === config.modelParameters.modelName)?.isComplete
|
||||
&& <PresenceBadge status="available" />}
|
||||
</div>
|
||||
</Option>
|
||||
)}
|
||||
</Dropdown>;
|
||||
});
|
||||
@@ -17,7 +17,9 @@ import { ToolTipButton } from './ToolTipButton';
|
||||
import { Play16Regular, Stop16Regular } from '@fluentui/react-icons';
|
||||
import { useNavigate } from 'react-router';
|
||||
import { WindowShow } from '../../wailsjs/runtime';
|
||||
import { convertToSt } from '../utils/convert-to-st';
|
||||
import { convertToGGML, convertToSt } from '../utils/convert-model';
|
||||
import { Precision } from '../types/configs';
|
||||
import { defaultCompositionABCPrompt, defaultCompositionPrompt } from '../pages/defaultConfigs';
|
||||
|
||||
const mainButtonText = {
|
||||
[ModelStatus.Offline]: 'Run',
|
||||
@@ -47,6 +49,8 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
|
||||
const modelConfig = commonStore.getCurrentModelConfig();
|
||||
const webgpu = modelConfig.modelParameters.device === 'WebGPU';
|
||||
const webgpuPython = modelConfig.modelParameters.device === 'WebGPU (Python)';
|
||||
const cpp = modelConfig.modelParameters.device === 'CPU (rwkv.cpp)';
|
||||
let modelName = '';
|
||||
let modelPath = '';
|
||||
if (modelConfig && modelConfig.modelParameters) {
|
||||
@@ -75,7 +79,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
});
|
||||
};
|
||||
|
||||
if (webgpu) {
|
||||
if (webgpu || webgpuPython) {
|
||||
if (!['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) {
|
||||
const stModelPath = modelPath.replace(/\.pth$/, '.st');
|
||||
if (await FileExists(stModelPath)) {
|
||||
@@ -90,7 +94,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
return;
|
||||
} else {
|
||||
toastWithButton(t('Please convert model to safe tensors format first'), t('Convert'), () => {
|
||||
convertToSt(navigate, modelConfig);
|
||||
convertToSt(modelConfig, navigate);
|
||||
});
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
return;
|
||||
@@ -98,7 +102,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
}
|
||||
}
|
||||
|
||||
if (!webgpu) {
|
||||
if (!webgpu && !webgpuPython) {
|
||||
if (['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) {
|
||||
toast(t('Please change Strategy to WebGPU to use safetensors format'), { type: 'error' });
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
@@ -112,6 +116,38 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
return;
|
||||
}
|
||||
|
||||
if (cpp) {
|
||||
if (!['.bin'].some(ext => modelPath.endsWith(ext))) {
|
||||
const precision: Precision = modelConfig.modelParameters.precision === 'Q5_1' ? 'Q5_1' : 'fp16';
|
||||
const ggmlModelPath = modelPath.replace(/\.pth$/, `-${precision}.bin`);
|
||||
if (await FileExists(ggmlModelPath)) {
|
||||
modelPath = ggmlModelPath;
|
||||
} else if (!await FileExists(modelPath)) {
|
||||
showDownloadPrompt(t('Model file not found'), modelName);
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
return;
|
||||
} else if (!currentModelSource?.isComplete) {
|
||||
showDownloadPrompt(t('Model file download is not complete'), modelName);
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
return;
|
||||
} else {
|
||||
toastWithButton(t('Please convert model to GGML format first'), t('Convert'), () => {
|
||||
convertToGGML(modelConfig, navigate);
|
||||
});
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!cpp) {
|
||||
if (['.bin'].some(ext => modelPath.endsWith(ext))) {
|
||||
toast(t('Please change Strategy to CPU (rwkv.cpp) to use ggml format'), { type: 'error' });
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!await FileExists(modelPath)) {
|
||||
showDownloadPrompt(t('Model file not found'), modelName);
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
@@ -142,7 +178,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
const isUsingCudaBeta = modelConfig.modelParameters.device === 'CUDA-Beta';
|
||||
|
||||
startServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1',
|
||||
!!modelConfig.enableWebUI, isUsingCudaBeta
|
||||
!!modelConfig.enableWebUI, isUsingCudaBeta, cpp, webgpuPython
|
||||
).catch((e) => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
@@ -169,7 +205,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
});
|
||||
}
|
||||
commonStore.setStatus({ status: ModelStatus.Loading });
|
||||
const loadingId = toast(t('Loading Model'), { type: 'info' });
|
||||
const loadingId = toast(t('Loading Model'), { type: 'info', autoClose: false });
|
||||
if (!webgpu) {
|
||||
updateConfig({
|
||||
max_tokens: modelConfig.apiParameters.maxResponseToken,
|
||||
@@ -182,8 +218,9 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
|
||||
const strategy = getStrategy(modelConfig);
|
||||
let customCudaFile = '';
|
||||
if ((modelConfig.modelParameters.device.includes('CUDA') || modelConfig.modelParameters.device === 'Custom')
|
||||
&& modelConfig.modelParameters.useCustomCuda && !strategy.includes('fp32')) {
|
||||
if ((modelConfig.modelParameters.device.startsWith('CUDA') || modelConfig.modelParameters.device === 'Custom')
|
||||
&& modelConfig.modelParameters.useCustomCuda
|
||||
&& !strategy.split('->').some(s => ['cuda', 'fp32'].every(v => s.includes(v)))) {
|
||||
if (commonStore.platform === 'windows') {
|
||||
// this part is currently unused because there's no longer a need to use different kernels for different GPUs, but it might still be needed in the future
|
||||
//
|
||||
@@ -221,6 +258,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
commonStore.setStatus({ status: ModelStatus.Working });
|
||||
let buttonNameMap = {
|
||||
'novel': 'Completion',
|
||||
'abc': 'Composition',
|
||||
'midi': 'Composition'
|
||||
};
|
||||
let buttonName = 'Chat';
|
||||
@@ -228,11 +266,18 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
const buttonFn = () => {
|
||||
navigate({ pathname: '/' + buttonName.toLowerCase() });
|
||||
};
|
||||
if (modelName.toLowerCase().includes('abc') && commonStore.compositionParams.prompt === defaultCompositionPrompt) {
|
||||
commonStore.setCompositionParams({
|
||||
...commonStore.compositionParams,
|
||||
prompt: defaultCompositionABCPrompt
|
||||
});
|
||||
commonStore.setCompositionSubmittedPrompt(defaultCompositionABCPrompt);
|
||||
}
|
||||
|
||||
if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'CUDA-Beta') &&
|
||||
if (modelConfig.modelParameters.device.startsWith('CUDA') &&
|
||||
modelConfig.modelParameters.storedLayers < modelConfig.modelParameters.maxStoredLayers &&
|
||||
commonStore.monitorData && commonStore.monitorData.totalVram !== 0 &&
|
||||
(commonStore.monitorData.usedVram / commonStore.monitorData.totalVram) < 0.85)
|
||||
(commonStore.monitorData.usedVram / commonStore.monitorData.totalVram) < 0.9)
|
||||
toast(t('You can increase the number of stored layers in Configs page to improve performance'), { type: 'info' });
|
||||
toastWithButton(t('Startup Completed'), t(buttonName), buttonFn, { type: 'success', autoClose: 3000 });
|
||||
} else if (r.status === 304) {
|
||||
@@ -244,9 +289,10 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
'not enough memory': 'Memory is not enough, try to increase the virtual memory or use a smaller model.',
|
||||
'not compiled with CUDA': 'Bad PyTorch version, please reinstall PyTorch with cuda.',
|
||||
'invalid header or archive is corrupted': 'The model file is corrupted, please download again.',
|
||||
'no NVIDIA driver': 'Found no NVIDIA driver, please install the latest driver.',
|
||||
'no NVIDIA driver': 'Found no NVIDIA driver, please install the latest driver. If you are not using an Nvidia GPU, please switch the \'Strategy\' to WebGPU or CPU in the Configs page.',
|
||||
'CUDA out of memory': 'VRAM is not enough, please reduce stored layers or use a lower precision in Configs page.',
|
||||
'Ninja is required to load C++ extensions': 'Failed to enable custom CUDA kernel, ninja is required to load C++ extensions. You may be using the CPU version of PyTorch, please reinstall PyTorch with CUDA. Or if you are using a custom Python interpreter, you must compile the CUDA kernel by yourself or disable Custom CUDA kernel acceleration.'
|
||||
'Ninja is required to load C++ extensions': 'Failed to enable custom CUDA kernel, ninja is required to load C++ extensions. You may be using the CPU version of PyTorch, please reinstall PyTorch with CUDA. Or if you are using a custom Python interpreter, you must compile the CUDA kernel by yourself or disable Custom CUDA kernel acceleration.',
|
||||
're-convert the model': 'Model has been converted and does not match current strategy. If you are using a new strategy, re-convert the model.'
|
||||
};
|
||||
const matchedError = Object.entries(errorsMap).find(([key, _]) => error.includes(key));
|
||||
const message = matchedError ? t(matchedError[1]) : error;
|
||||
|
||||
@@ -7,6 +7,7 @@ import { v4 as uuid } from 'uuid';
|
||||
import {
|
||||
Add16Regular,
|
||||
ArrowAutofitWidth20Regular,
|
||||
ArrowUpload16Regular,
|
||||
Delete16Regular,
|
||||
MusicNote220Regular,
|
||||
Pause16Regular,
|
||||
@@ -17,19 +18,25 @@ import {
|
||||
} from '@fluentui/react-icons';
|
||||
import { Button, Card, DialogTrigger, Slider, Text, Tooltip } from '@fluentui/react-components';
|
||||
import { useWindowSize } from 'usehooks-ts';
|
||||
import commonStore from '../../stores/commonStore';
|
||||
import commonStore, { ModelStatus } from '../../stores/commonStore';
|
||||
import classnames from 'classnames';
|
||||
import {
|
||||
InstrumentType,
|
||||
InstrumentTypeNameMap,
|
||||
InstrumentTypeTokenMap,
|
||||
MidiMessage,
|
||||
tracksMinimalTotalTime
|
||||
} from '../../types/composition';
|
||||
import { toast } from 'react-toastify';
|
||||
import { ToastOptions } from 'react-toastify/dist/types';
|
||||
import { flushMidiRecordingContent, refreshTracksTotalTime } from '../../utils';
|
||||
import { PlayNote } from '../../../wailsjs/go/backend_golang/App';
|
||||
import { t } from 'i18next';
|
||||
import {
|
||||
absPathAsset,
|
||||
flushMidiRecordingContent,
|
||||
getMidiRawContentMainInstrument,
|
||||
getMidiRawContentTime,
|
||||
getServerRoot,
|
||||
refreshTracksTotalTime
|
||||
} from '../../utils';
|
||||
import { OpenOpenFileDialog, PlayNote } from '../../../wailsjs/go/backend_golang/App';
|
||||
|
||||
const snapValue = 25;
|
||||
const minimalMoveTime = 8; // 1000/125=8ms wait_events=125
|
||||
@@ -47,52 +54,62 @@ const pixelFix = 0.5;
|
||||
const topToArrowIcon = 19;
|
||||
const arrowIconToTracks = 23;
|
||||
|
||||
type TrackProps = {
|
||||
id: string;
|
||||
right: number;
|
||||
scale: number;
|
||||
isSelected: boolean;
|
||||
onSelect: (id: string) => void;
|
||||
};
|
||||
|
||||
const displayCurrentInstrumentType = () => {
|
||||
const displayPanelId = 'instrument_panel_id';
|
||||
const content: React.ReactNode =
|
||||
<div className="flex gap-2 items-center">
|
||||
{InstrumentTypeNameMap.map((name, i) =>
|
||||
<Text key={name} style={{ whiteSpace: 'nowrap' }}
|
||||
className={commonStore.instrumentType === i ? 'text-blue-600' : ''}
|
||||
weight={commonStore.instrumentType === i ? 'bold' : 'regular'}
|
||||
size={commonStore.instrumentType === i ? 300 : 100}
|
||||
>{t(name)}</Text>)}
|
||||
</div>;
|
||||
const options: ToastOptions = {
|
||||
type: 'default',
|
||||
autoClose: 2000,
|
||||
toastId: displayPanelId,
|
||||
position: 'top-left',
|
||||
style: {
|
||||
width: 'fit-content'
|
||||
}
|
||||
};
|
||||
if (toast.isActive(displayPanelId))
|
||||
toast.update(displayPanelId, {
|
||||
render: content,
|
||||
...options
|
||||
});
|
||||
else
|
||||
toast(content, options);
|
||||
};
|
||||
|
||||
const velocityToBin = (velocity: number) => {
|
||||
velocity = Math.max(0, Math.min(velocity, velocityEvents - 1));
|
||||
const binsize = velocityEvents / (velocityBins - 1);
|
||||
return Math.ceil((velocityEvents * ((Math.pow(velocityExp, (velocity / velocityEvents)) - 1.0) / (velocityExp - 1.0))) / binsize);
|
||||
};
|
||||
|
||||
const binToVelocity = (bin: number) => {
|
||||
const binsize = velocityEvents / (velocityBins - 1);
|
||||
return Math.max(0, Math.ceil(velocityEvents * (Math.log(((velocityExp - 1) * binsize * bin) / velocityEvents + 1) / Math.log(velocityExp)) - 1));
|
||||
};
|
||||
|
||||
const tokenToMidiMessage = (token: string): MidiMessage | null => {
|
||||
if (token.startsWith('<')) return null;
|
||||
if (token.startsWith('t') && !token.startsWith('t:')) {
|
||||
return {
|
||||
messageType: 'ElapsedTime',
|
||||
value: parseInt(token.substring(1)) * minimalMoveTime,
|
||||
channel: 0,
|
||||
note: 0,
|
||||
velocity: 0,
|
||||
control: 0,
|
||||
instrument: 0
|
||||
};
|
||||
}
|
||||
const instrument: InstrumentType = InstrumentTypeTokenMap.findIndex(t => token.startsWith(t + ':'));
|
||||
if (instrument >= 0) {
|
||||
const parts = token.split(':');
|
||||
if (parts.length !== 3) return null;
|
||||
const note = parseInt(parts[1], 16);
|
||||
const velocity = parseInt(parts[2], 16);
|
||||
if (velocity < 0 || velocity > 127) return null;
|
||||
if (velocity === 0) return {
|
||||
messageType: 'NoteOff',
|
||||
note: note,
|
||||
instrument: instrument,
|
||||
channel: 0,
|
||||
velocity: 0,
|
||||
control: 0,
|
||||
value: 0
|
||||
};
|
||||
return {
|
||||
messageType: 'NoteOn',
|
||||
note: note,
|
||||
velocity: binToVelocity(velocity),
|
||||
instrument: instrument,
|
||||
channel: 0,
|
||||
control: 0,
|
||||
value: 0
|
||||
} as MidiMessage;
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const midiMessageToToken = (msg: MidiMessage) => {
|
||||
if (msg.messageType === 'NoteOn') {
|
||||
const instrument = InstrumentTypeTokenMap[commonStore.instrumentType];
|
||||
if (msg.messageType === 'NoteOn' || msg.messageType === 'NoteOff') {
|
||||
const instrument = InstrumentTypeTokenMap[msg.instrument];
|
||||
const note = msg.note.toString(16);
|
||||
const velocity = velocityToBin(msg.velocity).toString(16);
|
||||
return `${instrument}:${note}:${velocity} `;
|
||||
@@ -116,7 +133,6 @@ let dropRecordingTime = false;
|
||||
export const midiMessageHandler = async (data: MidiMessage) => {
|
||||
if (data.messageType === 'ControlChange') {
|
||||
commonStore.setInstrumentType(Math.round(data.value / 127 * (InstrumentTypeNameMap.length - 1)));
|
||||
displayCurrentInstrumentType();
|
||||
return;
|
||||
}
|
||||
if (commonStore.recordingTrackId) {
|
||||
@@ -136,6 +152,14 @@ export const midiMessageHandler = async (data: MidiMessage) => {
|
||||
}
|
||||
};
|
||||
|
||||
type TrackProps = {
|
||||
id: string;
|
||||
right: number;
|
||||
scale: number;
|
||||
isSelected: boolean;
|
||||
onSelect: (id: string) => void;
|
||||
};
|
||||
|
||||
const Track: React.FC<TrackProps> = observer(({
|
||||
id,
|
||||
right,
|
||||
@@ -146,9 +170,15 @@ const Track: React.FC<TrackProps> = observer(({
|
||||
const { t } = useTranslation();
|
||||
const trackIndex = commonStore.tracks.findIndex(t => t.id === id)!;
|
||||
const track = commonStore.tracks[trackIndex];
|
||||
const trackClass = isSelected ? 'bg-blue-600' : 'bg-gray-700';
|
||||
const trackClass = isSelected ? 'bg-blue-600' : (commonStore.settings.darkMode ? 'bg-blue-900' : 'bg-gray-700');
|
||||
const controlX = useRef(0);
|
||||
|
||||
let trackName = t('Track') + ' ' + id;
|
||||
if (track.mainInstrument)
|
||||
trackName = t('Track') + ' - ' + t('Piano is the main instrument')!.replace(t('Piano')!, t(track.mainInstrument)) + (track.content && (' - ' + track.content));
|
||||
else if (track.content)
|
||||
trackName = t('Track') + ' - ' + track.content;
|
||||
|
||||
return (
|
||||
<Draggable
|
||||
axis="x"
|
||||
@@ -183,7 +213,7 @@ const Track: React.FC<TrackProps> = observer(({
|
||||
}}
|
||||
onClick={() => onSelect(id)}
|
||||
>
|
||||
<span className="text-white">{t('Track') + ' ' + (track.content || id)}</span>
|
||||
<span className="text-white">{trackName}</span>
|
||||
</div>
|
||||
</Draggable>
|
||||
);
|
||||
@@ -298,7 +328,8 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
|
||||
}, 1);
|
||||
}}
|
||||
>
|
||||
<div ref={currentTimeControlRef} className="h-2 bg-gray-700 cursor-move rounded"
|
||||
<div ref={currentTimeControlRef}
|
||||
className={classnames('h-2 cursor-move rounded', commonStore.settings.darkMode ? 'bg-neutral-600' : 'bg-gray-700')}
|
||||
style={{ width: currentTimeControlWidth }} />
|
||||
</Draggable>
|
||||
<div className={classnames(
|
||||
@@ -329,7 +360,8 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
|
||||
<div className="relative cursor-move"
|
||||
ref={playStartTimeControlRef}>
|
||||
<ArrowAutofitWidth20Regular />
|
||||
<div className="border-l absolute border-gray-700"
|
||||
<div
|
||||
className={classnames('border-l absolute', commonStore.settings.darkMode ? 'border-white' : 'border-gray-700')}
|
||||
style={{
|
||||
height: (tracksRef.current && commonStore.tracks.length > 0)
|
||||
? tracksRef.current.clientHeight
|
||||
@@ -356,6 +388,7 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
|
||||
<div key={track.id} className="flex gap-2 pb-1 border-b">
|
||||
<div className="flex gap-1 border-r h-7">
|
||||
<ToolTipButton desc={commonStore.recordingTrackId === track.id ? t('Stop') : t('Record')}
|
||||
disabled={commonStore.platform === 'web'}
|
||||
icon={commonStore.recordingTrackId === track.id ? <Stop16Filled /> : <Record16Regular />}
|
||||
size="small" shape="circular" appearance="subtle"
|
||||
onClick={() => {
|
||||
@@ -365,6 +398,11 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
|
||||
if (commonStore.recordingTrackId === track.id) {
|
||||
commonStore.setRecordingTrackId('');
|
||||
} else {
|
||||
if (commonStore.activeMidiDeviceIndex === -1) {
|
||||
toast(t('Please select a MIDI device first'), { type: 'warning' });
|
||||
return;
|
||||
}
|
||||
|
||||
dropRecordingTime = true;
|
||||
setSelectedTrackId(track.id);
|
||||
|
||||
@@ -393,6 +431,7 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
|
||||
appearance="subtle" onClick={() => {
|
||||
const tracks = commonStore.tracks.slice().filter(t => t.id !== track.id);
|
||||
commonStore.setTracks(tracks);
|
||||
refreshTracksTotalTime();
|
||||
}} />
|
||||
</div>
|
||||
<div className="relative grow overflow-hidden">
|
||||
@@ -408,19 +447,76 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
|
||||
</div>
|
||||
</div>)}
|
||||
<div className="flex justify-between items-center">
|
||||
<Button icon={<Add16Regular />} size="small" shape="circular"
|
||||
appearance="subtle"
|
||||
onClick={() => {
|
||||
commonStore.setTracks([...commonStore.tracks, {
|
||||
id: uuid(),
|
||||
content: '',
|
||||
rawContent: [],
|
||||
offsetTime: 0,
|
||||
contentTime: 0
|
||||
}]);
|
||||
}}>
|
||||
{t('New Track')}
|
||||
</Button>
|
||||
<div className="flex gap-1">
|
||||
<Button icon={<Add16Regular />} size="small" shape="circular"
|
||||
appearance="subtle"
|
||||
disabled={commonStore.platform === 'web'}
|
||||
onClick={() => {
|
||||
commonStore.setTracks([...commonStore.tracks, {
|
||||
id: uuid(),
|
||||
mainInstrument: '',
|
||||
content: '',
|
||||
rawContent: [],
|
||||
offsetTime: 0,
|
||||
contentTime: 0
|
||||
}]);
|
||||
}}>
|
||||
{t('New Track')}
|
||||
</Button>
|
||||
<Button icon={<ArrowUpload16Regular />} size="small" shape="circular"
|
||||
appearance="subtle"
|
||||
onClick={() => {
|
||||
if (commonStore.status.status === ModelStatus.Offline && !commonStore.settings.apiUrl && commonStore.platform !== 'web') {
|
||||
toast(t('Please click the button in the top right corner to start the model'), { type: 'warning' });
|
||||
return;
|
||||
}
|
||||
|
||||
OpenOpenFileDialog('*.mid').then(async filePath => {
|
||||
if (!filePath)
|
||||
return;
|
||||
|
||||
let blob: Blob;
|
||||
if (commonStore.platform === 'web')
|
||||
blob = (filePath as unknown as { blob: Blob }).blob;
|
||||
else
|
||||
blob = await fetch(absPathAsset(filePath)).then(r => r.blob());
|
||||
const bodyForm = new FormData();
|
||||
bodyForm.append('file_data', blob);
|
||||
fetch(getServerRoot(commonStore.getCurrentModelConfig().apiParameters.apiPort) + '/midi-to-text', {
|
||||
method: 'POST',
|
||||
body: bodyForm
|
||||
}).then(async r => {
|
||||
if (r.status === 200) {
|
||||
const text = (await r.json()).text as string;
|
||||
const rawContent = text.split(' ').map(tokenToMidiMessage).filter(m => m) as MidiMessage[];
|
||||
const tracks = commonStore.tracks.slice();
|
||||
|
||||
tracks.push({
|
||||
id: uuid(),
|
||||
mainInstrument: getMidiRawContentMainInstrument(rawContent),
|
||||
content: text,
|
||||
rawContent: rawContent,
|
||||
offsetTime: 0,
|
||||
contentTime: getMidiRawContentTime(rawContent)
|
||||
});
|
||||
commonStore.setTracks(tracks);
|
||||
refreshTracksTotalTime();
|
||||
} else {
|
||||
toast(r.statusText + '\n' + (await r.text()), {
|
||||
type: 'error'
|
||||
});
|
||||
}
|
||||
}
|
||||
).catch(e => {
|
||||
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
|
||||
});
|
||||
}).catch(e => {
|
||||
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
|
||||
});
|
||||
}}>
|
||||
{t('Import MIDI')}
|
||||
</Button>
|
||||
</div>
|
||||
<Text size={100}>
|
||||
{t('Select a track to preview the content')}
|
||||
</Text>
|
||||
@@ -431,7 +527,7 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
|
||||
<Card size="small" appearance="outline" style={{ minHeight: '150px', maxHeight: '200px' }}>
|
||||
<div className="flex flex-col gap-1 overflow-hidden">
|
||||
<Text size={100}>{`${t('Start Time')}: ${selectedTrack.offsetTime} ms`}</Text>
|
||||
<Text size={100}>{`${t('Content Time')}: ${selectedTrack.contentTime} ms`}</Text>
|
||||
<Text size={100}>{`${t('Content Duration')}: ${selectedTrack.contentTime} ms`}</Text>
|
||||
<div className="overflow-y-auto overflow-x-hidden" ref={contentPreviewRef}>
|
||||
{selectedTrackId === commonStore.recordingTrackId
|
||||
? commonStore.recordingContent
|
||||
@@ -440,6 +536,18 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
|
||||
</div>
|
||||
</Card>
|
||||
}
|
||||
{
|
||||
commonStore.platform !== 'web' &&
|
||||
<div className="flex gap-2 items-end mx-auto">
|
||||
{t('Current Instrument') + ':'}
|
||||
{InstrumentTypeNameMap.map((name, i) =>
|
||||
<Text key={name} style={{ whiteSpace: 'nowrap' }}
|
||||
className={commonStore.instrumentType === i ? 'text-blue-600' : ''}
|
||||
weight={commonStore.instrumentType === i ? 'bold' : 'regular'}
|
||||
size={commonStore.instrumentType === i ? 300 : 100}
|
||||
>{t(name)}</Text>)}
|
||||
</div>
|
||||
}
|
||||
<DialogTrigger disableButtonEnhancement>
|
||||
<Button icon={<MusicNote220Regular />} style={{ minHeight: '32px' }} onClick={() => {
|
||||
flushMidiRecordingContent();
|
||||
@@ -472,14 +580,14 @@ const AudiotrackEditor: FC<{ setPrompt: (prompt: string) => void }> = observer((
|
||||
if (msg.messageType === 'ElapsedTime') {
|
||||
accContentTime += msg.value;
|
||||
currentTime = track.offsetTime + accContentTime;
|
||||
} else if (msg.messageType === 'NoteOn') {
|
||||
} else if (msg.messageType === 'NoteOn' || msg.messageType === 'NoteOff') {
|
||||
const insertIndex = sortedTimestamp.findIndex(t => t >= currentTime);
|
||||
globalMessages.splice(insertIndex + 1, 0, msg);
|
||||
sortedTimestamp.splice(insertIndex + 1, 0, 0); // placeholder
|
||||
}
|
||||
}
|
||||
}
|
||||
const result = globalMessages.map(m => midiMessageToToken(m)).join('');
|
||||
const result = ('<pad> ' + globalMessages.map(midiMessageToToken).join('')).trim();
|
||||
commonStore.setCompositionSubmittedPrompt(result);
|
||||
setPrompt(result);
|
||||
}}>
|
||||
|
||||
@@ -28,7 +28,7 @@ import { toast } from 'react-toastify';
|
||||
import { WorkHeader } from '../components/WorkHeader';
|
||||
import { DialogButton } from '../components/DialogButton';
|
||||
import { OpenFileFolder, OpenOpenFileDialog, OpenSaveFileDialog } from '../../wailsjs/go/backend_golang/App';
|
||||
import { absPathAsset, bytesToReadable, getServerRoot, toastWithButton } from '../utils';
|
||||
import { absPathAsset, bytesToReadable, getServerRoot, setActivePreset, toastWithButton } from '../utils';
|
||||
import { useMediaQuery } from 'usehooks-ts';
|
||||
import { botName, ConversationMessage, MessageType, userName, welcomeUuid } from '../types/chat';
|
||||
import { Labeled } from '../components/Labeled';
|
||||
@@ -436,7 +436,7 @@ const ChatPanel: FC = observer(() => {
|
||||
const chatSseController = new AbortController();
|
||||
chatSseControllers[answerId] = chatSseController;
|
||||
fetchEventSource( // https://api.openai.com/v1/chat/completions || http://127.0.0.1:${port}/v1/chat/completions
|
||||
getServerRoot(port) + '/v1/chat/completions',
|
||||
getServerRoot(port, true) + '/v1/chat/completions',
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -536,8 +536,7 @@ const ChatPanel: FC = observer(() => {
|
||||
}
|
||||
chatSseControllers = {};
|
||||
}
|
||||
commonStore.setConversation({});
|
||||
commonStore.setConversationOrder([]);
|
||||
setActivePreset(commonStore.activePreset);
|
||||
}} />
|
||||
<div className="relative flex grow">
|
||||
<Textarea
|
||||
@@ -554,7 +553,7 @@ const ChatPanel: FC = observer(() => {
|
||||
{!commonStore.currentTempAttachment ?
|
||||
<ToolTipButton
|
||||
desc={commonStore.attachmentUploading ?
|
||||
t('Uploading Attachment') :
|
||||
t('Processing Attachment') :
|
||||
t('Add An Attachment (Accepts pdf, txt)')}
|
||||
icon={commonStore.attachmentUploading ?
|
||||
<ArrowClockwise16Regular className="animate-spin" />
|
||||
@@ -568,7 +567,7 @@ const ChatPanel: FC = observer(() => {
|
||||
const setUploading = () => commonStore.setAttachmentUploading(true);
|
||||
// actually, status of web platform is always Offline
|
||||
if (commonStore.platform === 'web' || commonStore.status.status === ModelStatus.Offline || currentConfig.modelParameters.device === 'WebGPU') {
|
||||
webOpenOpenFileDialog({ filterPattern, fnStartLoading: setUploading }).then(webReturn => {
|
||||
webOpenOpenFileDialog(filterPattern, setUploading).then(webReturn => {
|
||||
if (webReturn.content)
|
||||
commonStore.setCurrentTempAttachment(
|
||||
{
|
||||
|
||||
@@ -82,7 +82,7 @@ const CompletionPanel: FC = observer(() => {
|
||||
let answer = '';
|
||||
completionSseController = new AbortController();
|
||||
fetchEventSource( // https://api.openai.com/v1/completions || http://127.0.0.1:${port}/v1/completions
|
||||
getServerRoot(port) + '/v1/completions',
|
||||
getServerRoot(port, true) + '/v1/completions',
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
|
||||
@@ -15,13 +15,16 @@ import { ArrowSync20Regular, Save28Regular } from '@fluentui/react-icons';
|
||||
import { PlayerElement, VisualizerElement } from 'html-midi-player';
|
||||
import * as mm from '@magenta/music/esm/core.js';
|
||||
import { NoteSequence } from '@magenta/music/esm/protobuf.js';
|
||||
import { defaultCompositionPrompt } from './defaultConfigs';
|
||||
import { defaultCompositionABCPrompt, defaultCompositionPrompt } from './defaultConfigs';
|
||||
import {
|
||||
CloseMidiPort,
|
||||
FileExists,
|
||||
OpenFileFolder,
|
||||
OpenMidiPort,
|
||||
OpenSaveFileDialogBytes
|
||||
OpenSaveFileDialog,
|
||||
OpenSaveFileDialogBytes,
|
||||
SaveFile,
|
||||
StartFile
|
||||
} from '../../wailsjs/go/backend_golang/App';
|
||||
import { getServerRoot, getSoundFont, toastWithButton } from '../utils';
|
||||
import { CompositionParams } from '../types/composition';
|
||||
@@ -34,7 +37,9 @@ const CompositionPanel: FC = observer(() => {
|
||||
const { t } = useTranslation();
|
||||
const mq = useMediaQuery('(min-width: 640px)');
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
const port = commonStore.getCurrentModelConfig().apiParameters.apiPort;
|
||||
const modelConfig = commonStore.getCurrentModelConfig();
|
||||
const port = modelConfig.apiParameters.apiPort;
|
||||
const isABC = modelConfig.modelParameters.modelName.toLowerCase().includes('abc');
|
||||
const visualizerRef = useRef<VisualizerElement>(null);
|
||||
const playerRef = useRef<PlayerElement>(null);
|
||||
|
||||
@@ -98,7 +103,46 @@ const CompositionPanel: FC = observer(() => {
|
||||
}
|
||||
}, []);
|
||||
|
||||
const externalPlayListener = () => {
|
||||
const params = commonStore.compositionParams;
|
||||
const saveAndPlay = async (midi: ArrayBuffer, path: string) => {
|
||||
await SaveFile(path, Array.from(new Uint8Array(midi)));
|
||||
StartFile(path);
|
||||
};
|
||||
if (params.externalPlay) {
|
||||
if (params.midi) {
|
||||
setTimeout(() => {
|
||||
playerRef.current?.stop();
|
||||
});
|
||||
saveAndPlay(params.midi, './midi/last.mid').catch((e: string) => {
|
||||
if (e.includes('being used'))
|
||||
saveAndPlay(params.midi!, './midi/last-2.mid');
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
useEffect(() => {
|
||||
playerRef.current?.addEventListener('start', externalPlayListener);
|
||||
return () => {
|
||||
playerRef.current?.removeEventListener('start', externalPlayListener);
|
||||
};
|
||||
}, [params.externalPlay]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!(commonStore.activeMidiDeviceIndex in commonStore.midiPorts)) {
|
||||
commonStore.setActiveMidiDeviceIndex(-1);
|
||||
CloseMidiPort();
|
||||
}
|
||||
}, [commonStore.midiPorts]);
|
||||
|
||||
const generateNs = (autoPlay: boolean) => {
|
||||
if (commonStore.getCurrentModelConfig().modelParameters.modelName.toLowerCase().includes('abc')) {
|
||||
import('abcjs').then(ABCJS => {
|
||||
ABCJS.renderAbc('abc-paper', commonStore.compositionParams.prompt, { responsive: 'resize' });
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
fetch(getServerRoot(port) + '/text-to-midi', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -116,9 +160,16 @@ const CompositionPanel: FC = observer(() => {
|
||||
});
|
||||
updateNs(ns);
|
||||
if (autoPlay) {
|
||||
setTimeout(() => {
|
||||
playerRef.current?.start();
|
||||
});
|
||||
if (commonStore.compositionParams.externalPlay)
|
||||
externalPlayListener();
|
||||
else {
|
||||
if (commonStore.compositionParams.playOnlyGeneratedContent && playerRef.current) {
|
||||
playerRef.current.currentTime = Math.max(commonStore.compositionParams.generationStartTime - 1, 0);
|
||||
}
|
||||
setTimeout(() => {
|
||||
playerRef.current?.start();
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -136,7 +187,7 @@ const CompositionPanel: FC = observer(() => {
|
||||
let answer = '';
|
||||
compositionSseController = new AbortController();
|
||||
fetchEventSource( // https://api.openai.com/v1/completions || http://127.0.0.1:${port}/v1/completions
|
||||
getServerRoot(port) + '/v1/completions',
|
||||
getServerRoot(port, true) + '/v1/completions',
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -243,34 +294,58 @@ const CompositionPanel: FC = observer(() => {
|
||||
}} />
|
||||
} />
|
||||
<div className="grow" />
|
||||
<Checkbox className="select-none"
|
||||
size="large" label={t('Use Local Sound Font')} checked={params.useLocalSoundFont}
|
||||
onChange={async (_, data) => {
|
||||
if (data.checked) {
|
||||
if (!await FileExists('assets/sound-font/accordion/instrument.json')) {
|
||||
toast(t('Failed to load local sound font, please check if the files exist - assets/sound-font'),
|
||||
{ type: 'warning' });
|
||||
return;
|
||||
{
|
||||
commonStore.platform !== 'web' &&
|
||||
<Checkbox className="select-none"
|
||||
size="large" label={t('Use Local Sound Font')} checked={params.useLocalSoundFont}
|
||||
onChange={async (_, data) => {
|
||||
if (data.checked) {
|
||||
if (!await FileExists('assets/sound-font/accordion/instrument.json')) {
|
||||
toast(t('Failed to load local sound font, please check if the files exist - assets/sound-font'),
|
||||
{ type: 'warning' });
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
setParams({
|
||||
useLocalSoundFont: data.checked as boolean
|
||||
});
|
||||
setSoundFont();
|
||||
}} />
|
||||
setParams({
|
||||
useLocalSoundFont: data.checked as boolean
|
||||
});
|
||||
setSoundFont();
|
||||
}} />
|
||||
}
|
||||
{
|
||||
commonStore.platform === 'windows' &&
|
||||
<Checkbox className="select-none"
|
||||
size="large" label={t('Play With External Player')} checked={params.externalPlay}
|
||||
onChange={async (_, data) => {
|
||||
setParams({
|
||||
externalPlay: data.checked as boolean
|
||||
});
|
||||
}} />
|
||||
}
|
||||
<Checkbox className="select-none"
|
||||
size="large" label={t('Auto Play At The End')} checked={params.autoPlay} onChange={(_, data) => {
|
||||
setParams({
|
||||
autoPlay: data.checked as boolean
|
||||
});
|
||||
}} />
|
||||
{commonStore.platform !== 'web' &&
|
||||
<Labeled flex breakline label={t('MIDI Input')}
|
||||
desc={t('Select the MIDI input device to be used.')}
|
||||
content={
|
||||
<div className="flex flex-col gap-1">
|
||||
<Checkbox className="select-none"
|
||||
size="large" label={t('Only Auto Play Generated Content')} checked={params.playOnlyGeneratedContent}
|
||||
onChange={async (_, data) => {
|
||||
setParams({
|
||||
autoPlay: data.checked as boolean || commonStore.compositionParams.autoPlay,
|
||||
playOnlyGeneratedContent: data.checked as boolean
|
||||
});
|
||||
}} />
|
||||
<Labeled flex breakline label={t('MIDI Input')}
|
||||
desc={t('Select the MIDI input device to be used.')}
|
||||
content={
|
||||
<div className="flex flex-col gap-1">
|
||||
{
|
||||
commonStore.platform !== 'web' &&
|
||||
<Dropdown style={{ minWidth: 0 }}
|
||||
value={commonStore.activeMidiDeviceIndex === -1 ? t('None')! : commonStore.midiPorts[commonStore.activeMidiDeviceIndex].name}
|
||||
value={(commonStore.activeMidiDeviceIndex === -1 || !(commonStore.activeMidiDeviceIndex in commonStore.midiPorts))
|
||||
? t('None')!
|
||||
: commonStore.midiPorts[commonStore.activeMidiDeviceIndex].name}
|
||||
selectedOptions={[commonStore.activeMidiDeviceIndex.toString()]}
|
||||
onOptionSelect={(_, data) => {
|
||||
if (data.optionValue) {
|
||||
@@ -290,10 +365,10 @@ const CompositionPanel: FC = observer(() => {
|
||||
<Option key={i} value={i.toString()}>{p.name}</Option>)
|
||||
}
|
||||
</Dropdown>
|
||||
<AudiotrackButton setPrompt={setPrompt} />
|
||||
</div>
|
||||
} />
|
||||
}
|
||||
}
|
||||
<AudiotrackButton setPrompt={setPrompt} />
|
||||
</div>
|
||||
} />
|
||||
</div>
|
||||
<div className="flex justify-between gap-2">
|
||||
<ToolTipButton desc={t('Regenerate')} icon={<ArrowSync20Regular />} onClick={() => {
|
||||
@@ -305,8 +380,13 @@ const CompositionPanel: FC = observer(() => {
|
||||
<DialogButton className="grow" text={t('Reset')} title={t('Reset')}
|
||||
contentText={t('Are you sure you want to reset this page? It cannot be undone.')}
|
||||
onConfirm={() => {
|
||||
commonStore.setCompositionSubmittedPrompt(defaultCompositionPrompt);
|
||||
setPrompt(defaultCompositionPrompt);
|
||||
const isABC = commonStore.getCurrentModelConfig().modelParameters.modelName.toLowerCase().includes('abc');
|
||||
const defaultPrompt = isABC ? defaultCompositionABCPrompt : defaultCompositionPrompt;
|
||||
commonStore.setCompositionSubmittedPrompt(defaultPrompt);
|
||||
setParams({
|
||||
generationStartTime: 0
|
||||
});
|
||||
setPrompt(defaultPrompt);
|
||||
}} />
|
||||
<Button className="grow" appearance="primary" onClick={() => {
|
||||
if (commonStore.compositionGenerating) {
|
||||
@@ -315,6 +395,9 @@ const CompositionPanel: FC = observer(() => {
|
||||
generateNs(params.autoPlay);
|
||||
} else {
|
||||
commonStore.setCompositionGenerating(true);
|
||||
setParams({
|
||||
generationStartTime: playerRef.current ? playerRef.current.duration : 0
|
||||
});
|
||||
onSubmit(params.prompt);
|
||||
}
|
||||
}}>{!commonStore.compositionGenerating ? t('Generate') : t('Stop')}</Button>
|
||||
@@ -323,18 +406,33 @@ const CompositionPanel: FC = observer(() => {
|
||||
</div>
|
||||
<div className="flex flex-col">
|
||||
<div className="ml-auto mr-auto">
|
||||
<midi-visualizer
|
||||
ref={visualizerRef}
|
||||
type="waterfall"
|
||||
/>
|
||||
{isABC ? <div /> :
|
||||
<midi-visualizer
|
||||
ref={visualizerRef}
|
||||
type="waterfall"
|
||||
/>}
|
||||
</div>
|
||||
<div className="flex">
|
||||
<midi-player
|
||||
ref={playerRef}
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
{isABC ? <div className="flex flex-col overflow-y-auto grow m-1" style={{ maxHeight: '260px' }}>
|
||||
<div id="abc-paper" />
|
||||
</div> :
|
||||
<midi-player
|
||||
ref={playerRef}
|
||||
style={{ width: '100%' }}
|
||||
/>}
|
||||
<Button icon={<Save28Regular />} size={mq ? 'large' : 'medium'} appearance={mq ? 'secondary' : 'subtle'}
|
||||
onClick={() => {
|
||||
if (isABC) {
|
||||
OpenSaveFileDialog('*.txt', 'abc-music.txt', commonStore.compositionParams.prompt).then((path) => {
|
||||
if (path)
|
||||
toastWithButton(t('File Saved'), t('Open'), () => {
|
||||
OpenFileFolder(path, false);
|
||||
});
|
||||
}).catch((e) => {
|
||||
toast(t('Error') + ' - ' + (e.message || e), { type: 'error', autoClose: 2500 });
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (params.midi) {
|
||||
OpenSaveFileDialogBytes('*.mid', 'music.mid', Array.from(new Uint8Array(params.midi))).then((path) => {
|
||||
if (path)
|
||||
|
||||
@@ -8,12 +8,13 @@ import {
|
||||
Input,
|
||||
Label,
|
||||
Option,
|
||||
PresenceBadge,
|
||||
Select,
|
||||
Switch,
|
||||
Text
|
||||
} from '@fluentui/react-components';
|
||||
import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons';
|
||||
import React, { FC, useEffect, useRef } from 'react';
|
||||
import React, { FC, useCallback, useEffect, useRef } from 'react';
|
||||
import { Section } from '../components/Section';
|
||||
import { Labeled } from '../components/Labeled';
|
||||
import { ToolTipButton } from '../components/ToolTipButton';
|
||||
@@ -26,16 +27,39 @@ import { Page } from '../components/Page';
|
||||
import { useNavigate } from 'react-router';
|
||||
import { RunButton } from '../components/RunButton';
|
||||
import { updateConfig } from '../apis';
|
||||
import { ConvertModel, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App';
|
||||
import { checkDependencies, getStrategy } from '../utils';
|
||||
import { getStrategy } from '../utils';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { WindowShow } from '../../wailsjs/runtime';
|
||||
import strategyImg from '../assets/images/strategy.jpg';
|
||||
import strategyZhImg from '../assets/images/strategy_zh.jpg';
|
||||
import { ResetConfigsButton } from '../components/ResetConfigsButton';
|
||||
import { useMediaQuery } from 'usehooks-ts';
|
||||
import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs';
|
||||
import { convertToSt } from '../utils/convert-to-st';
|
||||
import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model';
|
||||
|
||||
const ConfigSelector: FC<{
|
||||
selectedIndex: number,
|
||||
updateSelectedIndex: (i: number) => void
|
||||
}> = observer(({ selectedIndex, updateSelectedIndex }) => {
|
||||
return (
|
||||
<Dropdown style={{ minWidth: 0 }} className="grow" value={commonStore.modelConfigs[selectedIndex].name}
|
||||
selectedOptions={[selectedIndex.toString()]}
|
||||
onOptionSelect={(_, data) => {
|
||||
if (data.optionValue) {
|
||||
updateSelectedIndex(Number(data.optionValue));
|
||||
}
|
||||
}}>
|
||||
{commonStore.modelConfigs.map((config, index) => <Option key={index} value={index.toString()}
|
||||
text={config.name}>
|
||||
<div className="flex justify-between grow">
|
||||
{config.name}
|
||||
{commonStore.modelSourceList.find(item => item.name === config.modelParameters.modelName)?.isComplete
|
||||
&& <PresenceBadge status="available" />}
|
||||
</div>
|
||||
</Option>
|
||||
)}
|
||||
</Dropdown>
|
||||
);
|
||||
});
|
||||
|
||||
const Configs: FC = observer(() => {
|
||||
const { t } = useTranslation();
|
||||
@@ -52,13 +76,13 @@ const Configs: FC = observer(() => {
|
||||
(advancedHeaderRef.current.firstElementChild as HTMLElement).style.padding = '0';
|
||||
}, []);
|
||||
|
||||
const updateSelectedIndex = (newIndex: number) => {
|
||||
const updateSelectedIndex = useCallback((newIndex: number) => {
|
||||
setSelectedIndex(newIndex);
|
||||
setSelectedConfig(commonStore.modelConfigs[newIndex]);
|
||||
|
||||
// if you don't want to update the config used by the current startup in real time, comment out this line
|
||||
commonStore.setCurrentConfigIndex(newIndex);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const setSelectedConfigName = (newName: string) => {
|
||||
setSelectedConfig({ ...selectedConfig, name: newName });
|
||||
@@ -98,17 +122,7 @@ const Configs: FC = observer(() => {
|
||||
<Page title={t('Configs')} content={
|
||||
<div className="flex flex-col gap-2 overflow-hidden">
|
||||
<div className="flex gap-2 items-center">
|
||||
<Dropdown style={{ minWidth: 0 }} className="grow" value={commonStore.modelConfigs[selectedIndex].name}
|
||||
selectedOptions={[selectedIndex.toString()]}
|
||||
onOptionSelect={(_, data) => {
|
||||
if (data.optionValue) {
|
||||
updateSelectedIndex(Number(data.optionValue));
|
||||
}
|
||||
}}>
|
||||
{commonStore.modelConfigs.map((config, index) =>
|
||||
<Option key={index} value={index.toString()}>{config.name}</Option>
|
||||
)}
|
||||
</Dropdown>
|
||||
<ConfigSelector selectedIndex={selectedIndex} updateSelectedIndex={updateSelectedIndex} />
|
||||
<ToolTipButton desc={t('New Config')} icon={<AddCircle20Regular />} onClick={() => {
|
||||
commonStore.createModelConfig();
|
||||
updateSelectedIndex(commonStore.modelConfigs.length - 1);
|
||||
@@ -214,9 +228,18 @@ const Configs: FC = observer(() => {
|
||||
<Select style={{ minWidth: 0 }} className="grow"
|
||||
value={selectedConfig.modelParameters.modelName}
|
||||
onChange={(e, data) => {
|
||||
setSelectedConfigModelParams({
|
||||
modelName: data.value
|
||||
});
|
||||
const modelSource = commonStore.modelSourceList.find(item => item.name === data.value);
|
||||
if (modelSource?.customTokenizer)
|
||||
setSelectedConfigModelParams({
|
||||
modelName: data.value,
|
||||
useCustomTokenizer: true,
|
||||
customTokenizer: modelSource?.customTokenizer
|
||||
});
|
||||
else // prevent customTokenizer from being overwritten
|
||||
setSelectedConfigModelParams({
|
||||
modelName: data.value,
|
||||
useCustomTokenizer: false
|
||||
});
|
||||
}}>
|
||||
{!commonStore.modelSourceList.find(item => item.name === selectedConfig.modelParameters.modelName)?.isComplete
|
||||
&& <option key={-1}
|
||||
@@ -232,48 +255,17 @@ const Configs: FC = observer(() => {
|
||||
</div>
|
||||
} />
|
||||
{
|
||||
selectedConfig.modelParameters.device !== 'WebGPU' ?
|
||||
<ToolTipButton text={t('Convert')}
|
||||
desc={t('Convert model with these configs. Using a converted model will greatly improve the loading speed, but model parameters of the converted model cannot be modified.')}
|
||||
onClick={async () => {
|
||||
if (commonStore.platform === 'darwin') {
|
||||
toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
|
||||
return;
|
||||
} else if (commonStore.platform === 'linux') {
|
||||
toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
|
||||
return;
|
||||
}
|
||||
|
||||
const ok = await checkDependencies(navigate);
|
||||
if (!ok)
|
||||
return;
|
||||
|
||||
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
|
||||
if (await FileExists(modelPath)) {
|
||||
const strategy = getStrategy(selectedConfig);
|
||||
const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-');
|
||||
toast(t('Start Converting'), { autoClose: 1000, type: 'info' });
|
||||
ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => {
|
||||
if (!await FileExists(newModelPath + '.pth')) {
|
||||
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
} else {
|
||||
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
|
||||
}
|
||||
}).catch(e => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
|
||||
else
|
||||
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
|
||||
});
|
||||
setTimeout(WindowShow, 1000);
|
||||
} else {
|
||||
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
|
||||
}
|
||||
}} /> :
|
||||
<ToolTipButton text={t('Convert To Safe Tensors Format')}
|
||||
!selectedConfig.modelParameters.device.startsWith('WebGPU') ?
|
||||
(selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' ?
|
||||
<ToolTipButton text={t('Convert')}
|
||||
desc={t('Convert model with these configs. Using a converted model will greatly improve the loading speed, but model parameters of the converted model cannot be modified.')}
|
||||
onClick={() => convertModel(selectedConfig, navigate)} /> :
|
||||
<ToolTipButton text={t('Convert To GGML Format')}
|
||||
desc=""
|
||||
onClick={() => convertToGGML(selectedConfig, navigate)} />)
|
||||
: <ToolTipButton text={t('Convert To Safe Tensors Format')}
|
||||
desc=""
|
||||
onClick={() => convertToSt(navigate, selectedConfig)} />
|
||||
onClick={() => convertToSt(selectedConfig, navigate)} />
|
||||
}
|
||||
<Labeled label={t('Strategy')} content={
|
||||
<Dropdown style={{ minWidth: 0 }} className="grow" value={t(selectedConfig.modelParameters.device)!}
|
||||
@@ -286,10 +278,12 @@ const Configs: FC = observer(() => {
|
||||
}
|
||||
}}>
|
||||
<Option value="CPU">CPU</Option>
|
||||
<Option value="CPU (rwkv.cpp)">{t('CPU (rwkv.cpp, Faster)')!}</Option>
|
||||
{commonStore.platform === 'darwin' && <Option value="MPS">MPS</Option>}
|
||||
<Option value="CUDA">CUDA</Option>
|
||||
<Option value="CUDA-Beta">{t('CUDA (Beta, Faster)')!}</Option>
|
||||
<Option value="WebGPU">WebGPU</Option>
|
||||
<Option value="WebGPU (Python)">WebGPU (Python)</Option>
|
||||
<Option value="Custom">{t('Custom')!}</Option>
|
||||
</Dropdown>
|
||||
} />
|
||||
@@ -297,7 +291,8 @@ const Configs: FC = observer(() => {
|
||||
selectedConfig.modelParameters.device !== 'Custom' && <Labeled label={t('Precision')}
|
||||
desc={t('int8 uses less VRAM, but has slightly lower quality. fp16 has higher quality.')}
|
||||
content={
|
||||
<Dropdown style={{ minWidth: 0 }} className="grow"
|
||||
<Dropdown
|
||||
style={{ minWidth: 0 }} className="grow"
|
||||
value={selectedConfig.modelParameters.precision}
|
||||
selectedOptions={[selectedConfig.modelParameters.precision]}
|
||||
onOptionSelect={(_, data) => {
|
||||
@@ -309,19 +304,21 @@ const Configs: FC = observer(() => {
|
||||
}}>
|
||||
{selectedConfig.modelParameters.device !== 'CPU' && selectedConfig.modelParameters.device !== 'MPS' &&
|
||||
<Option>fp16</Option>}
|
||||
<Option>int8</Option>
|
||||
{selectedConfig.modelParameters.device === 'WebGPU' && <Option>nf4</Option>}
|
||||
{selectedConfig.modelParameters.device !== 'WebGPU' && <Option>fp32</Option>}
|
||||
{selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && <Option>int8</Option>}
|
||||
{selectedConfig.modelParameters.device.startsWith('WebGPU') && <Option>nf4</Option>}
|
||||
{selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && !selectedConfig.modelParameters.device.startsWith('WebGPU') &&
|
||||
<Option>fp32</Option>}
|
||||
{selectedConfig.modelParameters.device === 'CPU (rwkv.cpp)' && <Option>Q5_1</Option>}
|
||||
</Dropdown>
|
||||
} />
|
||||
}
|
||||
{
|
||||
selectedConfig.modelParameters.device.includes('CUDA') &&
|
||||
selectedConfig.modelParameters.device.startsWith('CUDA') &&
|
||||
<Labeled label={t('Current Strategy')}
|
||||
content={<Text> {getStrategy(selectedConfig)} </Text>} />
|
||||
}
|
||||
{
|
||||
selectedConfig.modelParameters.device.includes('CUDA') &&
|
||||
selectedConfig.modelParameters.device.startsWith('CUDA') &&
|
||||
<Labeled label={t('Stored Layers')}
|
||||
desc={t('Number of the neural network layers loaded into VRAM, the more you load, the faster the speed, but it consumes more VRAM. (If your VRAM is not enough, it will fail to load)')}
|
||||
content={
|
||||
@@ -334,7 +331,7 @@ const Configs: FC = observer(() => {
|
||||
}} />
|
||||
} />
|
||||
}
|
||||
{selectedConfig.modelParameters.device.includes('CUDA') && <div />}
|
||||
{selectedConfig.modelParameters.device.startsWith('CUDA') && <div />}
|
||||
{
|
||||
displayStrategyImg &&
|
||||
<img style={{ width: '80vh', height: 'auto', zIndex: 100 }}
|
||||
@@ -359,7 +356,7 @@ const Configs: FC = observer(() => {
|
||||
}
|
||||
{selectedConfig.modelParameters.device === 'Custom' && <div />}
|
||||
{
|
||||
(selectedConfig.modelParameters.device.includes('CUDA') || selectedConfig.modelParameters.device === 'Custom') &&
|
||||
(selectedConfig.modelParameters.device.startsWith('CUDA') || selectedConfig.modelParameters.device === 'Custom') &&
|
||||
<Labeled label={t('Use Custom CUDA kernel to Accelerate')}
|
||||
desc={t('Enabling this option can greatly improve inference speed and save some VRAM, but there may be compatibility issues (output garbled). If it fails to start, please turn off this option, or try to upgrade your gpu driver.')}
|
||||
content={
|
||||
@@ -392,7 +389,7 @@ const Configs: FC = observer(() => {
|
||||
});
|
||||
}} />
|
||||
<Input className="grow"
|
||||
placeholder={t('Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json)')!}
|
||||
placeholder={t('Tokenizer Path (e.g. backend-python/rwkv_pip/20B_tokenizer.json or rwkv_vocab_v20230424.txt)')!}
|
||||
value={selectedConfig.modelParameters.customTokenizer}
|
||||
onChange={(e, data) => {
|
||||
setSelectedConfigModelParams({
|
||||
@@ -408,6 +405,7 @@ const Configs: FC = observer(() => {
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
{mq && <div style={{ minHeight: '30px' }} />}
|
||||
</div>
|
||||
<div className="flex flex-row-reverse sm:fixed bottom-2 right-2">
|
||||
<div className="flex gap-2">
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import React, { FC } from 'react';
|
||||
import React, { FC, useEffect, useState } from 'react';
|
||||
import {
|
||||
Button,
|
||||
Checkbox,
|
||||
createTableColumn,
|
||||
DataGrid,
|
||||
@@ -152,8 +153,33 @@ const columns: TableColumnDefinition<ModelSourceItem>[] = [
|
||||
})
|
||||
];
|
||||
|
||||
const getTags = () => {
|
||||
return Array.from(new Set(
|
||||
['Recommended',
|
||||
...commonStore.modelSourceList.map(item => item.tags || []).flat()
|
||||
.filter(i => !i.includes('Other') && !i.includes('Local'))
|
||||
, 'Other', 'Local']));
|
||||
};
|
||||
|
||||
const getCurrentModelList = () => {
|
||||
if (commonStore.activeModelListTags.length === 0)
|
||||
return commonStore.modelSourceList;
|
||||
else
|
||||
return commonStore.modelSourceList.filter(item => commonStore.activeModelListTags.some(tag => item.tags?.includes(tag)));
|
||||
};
|
||||
|
||||
const Models: FC = observer(() => {
|
||||
const { t } = useTranslation();
|
||||
const [tags, setTags] = useState<Array<string>>(getTags());
|
||||
const [modelSourceList, setModelSourceList] = useState<ModelSourceItem[]>(getCurrentModelList());
|
||||
|
||||
useEffect(() => {
|
||||
setTags(getTags());
|
||||
}, [commonStore.modelSourceList]);
|
||||
|
||||
useEffect(() => {
|
||||
setModelSourceList(getCurrentModelList());
|
||||
}, [commonStore.modelSourceList, commonStore.activeModelListTags]);
|
||||
|
||||
return (
|
||||
<Page title={t('Models')} content={
|
||||
@@ -184,9 +210,24 @@ const Models: FC = observer(() => {
|
||||
value={commonStore.modelSourceManifestList}
|
||||
onChange={(e, data) => commonStore.setModelSourceManifestList(data.value)} />
|
||||
</div>
|
||||
<div className="flex gap-2 flex-wrap overflow-y-auto" style={{ minHeight: '88px' }}>
|
||||
{tags.map(tag =>
|
||||
<div key={tag} className="mt-auto">
|
||||
<Button
|
||||
appearance={commonStore.activeModelListTags.includes(tag) ? 'primary' : 'secondary'} onClick={
|
||||
() => {
|
||||
if (commonStore.activeModelListTags.includes(tag))
|
||||
commonStore.setActiveModelListTags(commonStore.activeModelListTags.filter(t => t !== tag));
|
||||
else
|
||||
commonStore.setActiveModelListTags([...commonStore.activeModelListTags, tag]);
|
||||
}
|
||||
}>{t(tag)}</Button>
|
||||
</div>)
|
||||
}
|
||||
</div>
|
||||
<div className="flex grow overflow-hidden">
|
||||
<DataGrid
|
||||
items={commonStore.modelSourceList}
|
||||
items={modelSourceList}
|
||||
columns={columns}
|
||||
sortable={true}
|
||||
defaultSortState={{ sortColumn: 'actions', sortDirection: 'ascending' }}
|
||||
|
||||
@@ -33,10 +33,8 @@ import { observer } from 'mobx-react-lite';
|
||||
import { ClipboardGetText, ClipboardSetText } from '../../../wailsjs/runtime';
|
||||
import { toast } from 'react-toastify';
|
||||
import { CustomToastContainer } from '../../components/CustomToastContainer';
|
||||
import { v4 as uuid } from 'uuid';
|
||||
import { absPathAsset } from '../../utils';
|
||||
import { absPathAsset, setActivePreset } from '../../utils';
|
||||
import { Preset, PresetsNavigationItem } from '../../types/presets';
|
||||
import { botName, Conversation, MessageType, userName } from '../../types/chat';
|
||||
import { LazyImportComponent } from '../../components/LazyImportComponent';
|
||||
|
||||
const defaultPreset: Preset = {
|
||||
@@ -52,35 +50,17 @@ const defaultPreset: Preset = {
|
||||
prompt: '',
|
||||
stop: '',
|
||||
injectStart: '',
|
||||
injectEnd: ''
|
||||
injectEnd: '',
|
||||
presystem: true,
|
||||
userName: '',
|
||||
assistantName: ''
|
||||
};
|
||||
|
||||
const MessagesEditor = lazy(() => import('./MessagesEditor'));
|
||||
|
||||
const setActivePreset = (preset: Preset) => {
|
||||
commonStore.setActivePreset(preset);
|
||||
//TODO if (preset.displayPresetMessages) {
|
||||
const conversation: Conversation = {};
|
||||
const conversationOrder: string[] = [];
|
||||
for (const message of preset.messages) {
|
||||
const newUuid = uuid();
|
||||
conversationOrder.push(newUuid);
|
||||
conversation[newUuid] = {
|
||||
sender: message.role === 'user' ? userName : botName,
|
||||
type: MessageType.Normal,
|
||||
color: message.role === 'user' ? 'brand' : 'colorful',
|
||||
time: new Date().toISOString(),
|
||||
content: message.content,
|
||||
side: message.role === 'user' ? 'right' : 'left',
|
||||
done: true
|
||||
};
|
||||
}
|
||||
commonStore.setConversation(conversation);
|
||||
commonStore.setConversationOrder(conversationOrder);
|
||||
//}
|
||||
};
|
||||
|
||||
const PresetCardFrame: FC<PropsWithChildren & { onClick?: () => void }> = (props) => {
|
||||
const PresetCardFrame: FC<PropsWithChildren & {
|
||||
onClick?: React.MouseEventHandler<HTMLButtonElement>
|
||||
}> = (props) => {
|
||||
return <Button
|
||||
className="flex flex-col gap-1 w-32 h-56 break-all"
|
||||
style={{ minWidth: 0, borderRadius: '0.75rem', justifyContent: 'unset' }}
|
||||
@@ -103,7 +83,10 @@ const PresetCard: FC<{
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return <PresetCardFrame onClick={onClick}>
|
||||
return <PresetCardFrame onClick={(e) => {
|
||||
if (onClick && e.currentTarget.contains(e.target as Node))
|
||||
onClick();
|
||||
}}>
|
||||
<img src={absPathAsset(avatarImg)} className="rounded-xl select-none ml-auto mr-auto h-28" />
|
||||
<Text size={400}>{name}</Text>
|
||||
<Text size={200} style={{
|
||||
@@ -116,7 +99,8 @@ const PresetCard: FC<{
|
||||
{editable ?
|
||||
<ChatPresetEditor presetIndex={presetIndex} triggerButton={
|
||||
<ToolTipButton size="small" appearance="transparent" desc={t('Edit')} icon={<Edit20Regular />}
|
||||
onClick={() => {
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
commonStore.setEditingPreset({ ...commonStore.presets[presetIndex] });
|
||||
}} />
|
||||
} />
|
||||
@@ -371,7 +355,9 @@ const ChatPresets: FC = observer(() => {
|
||||
</div>;
|
||||
});
|
||||
|
||||
const pages: { [label: string]: PresetsNavigationItem } = {
|
||||
const pages: {
|
||||
[label: string]: PresetsNavigationItem
|
||||
} = {
|
||||
Chat: {
|
||||
icon: <Chat20Regular />,
|
||||
element: <ChatPresets />
|
||||
@@ -386,7 +372,9 @@ const pages: { [label: string]: PresetsNavigationItem } = {
|
||||
}
|
||||
};
|
||||
|
||||
const PresetsManager: FC<{ initTab: string }> = ({ initTab }) => {
|
||||
const PresetsManager: FC<{
|
||||
initTab: string
|
||||
}> = ({ initTab }) => {
|
||||
const { t } = useTranslation();
|
||||
const [tab, setTab] = useState(initTab);
|
||||
|
||||
|
||||
@@ -186,6 +186,16 @@ export const AdvancedGeneralSettings: FC = observer(() => {
|
||||
</Dropdown>
|
||||
</div>
|
||||
} />
|
||||
<Labeled label={t('Core API URL')}
|
||||
desc={t('Override core API URL(/chat/completions and /completions). If you don\'t know what this is, leave it blank.')}
|
||||
content={
|
||||
<Input style={{ minWidth: 0 }} className="grow" value={commonStore.settings.coreApiUrl}
|
||||
onChange={(e, data) => {
|
||||
commonStore.setSettings({
|
||||
coreApiUrl: data.value
|
||||
});
|
||||
}} />
|
||||
} />
|
||||
</div>;
|
||||
});
|
||||
|
||||
|
||||
@@ -66,6 +66,11 @@ const parseLossData = (data: string) => {
|
||||
const loss = parseFloat(lastMatch[8]);
|
||||
commonStore.setChartTitle(`Epoch ${epoch}: ${lastMatch[2]} - ${lastMatch[3]}/${lastMatch[4]} - ${lastMatch[5]}/${lastMatch[6]} - ${lastMatch[7]} Loss=${loss}`);
|
||||
addLossDataToChart(epoch, loss);
|
||||
if (loss > 5)
|
||||
toast(t('Loss is too high, please check the training data, and ensure your gpu driver is up to date.'), {
|
||||
type: 'warning',
|
||||
toastId: 'train_loss_high'
|
||||
});
|
||||
return true;
|
||||
};
|
||||
|
||||
@@ -126,7 +131,7 @@ const showError = (e: any) => {
|
||||
};
|
||||
|
||||
const errorsMap = Object.entries({
|
||||
'python3 ./finetune/lora/train.py': 'Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.',
|
||||
'python3 ./finetune/lora/$modelInfo': 'Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.',
|
||||
'cuda out of memory': 'VRAM is not enough',
|
||||
'valueerror: high <= 0': 'Training data is not enough, reduce context length or add more data for training',
|
||||
'+= \'+ptx\'': 'Can not find an Nvidia GPU. Perhaps the gpu driver of windows is too old, or you are using WSL 1 for training, please upgrade to WSL 2. e.g. Run "wsl --set-version Ubuntu-22.04 2"',
|
||||
@@ -134,6 +139,7 @@ const errorsMap = Object.entries({
|
||||
'cuda_home environment variable is not set': 'Matched CUDA is not installed',
|
||||
'unsupported gpu architecture': 'Matched CUDA is not installed',
|
||||
'error building extension \'fused_adam\'': 'Matched CUDA is not installed',
|
||||
'rwkv{version} is not supported': 'This version of RWKV is not supported yet.',
|
||||
'modelinfo is invalid': 'Failed to load model, try to increase the virtual memory (Swap of WSL) or use a smaller base model.'
|
||||
});
|
||||
|
||||
@@ -293,7 +299,6 @@ const LoraFinetune: FC = observer(() => {
|
||||
(loraParams.baseModel ? `--load_model models/${loraParams.baseModel} ` : '') +
|
||||
(loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') +
|
||||
`--data_file ${convertedDataPath} ` +
|
||||
`--vocab_size ${loraParams.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` +
|
||||
`--ctx_len ${ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
|
||||
`--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` +
|
||||
`--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` +
|
||||
@@ -464,11 +469,12 @@ const LoraFinetune: FC = observer(() => {
|
||||
return;
|
||||
if (loraParams.loraLoad) {
|
||||
const outputPath = `models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`;
|
||||
MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha,
|
||||
MergeLora(commonStore.settings.customPythonPath, !!commonStore.monitorData && commonStore.monitorData.totalVram !== 0, loraParams.loraAlpha,
|
||||
'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
|
||||
outputPath).then(async () => {
|
||||
if (!await FileExists(outputPath)) {
|
||||
toast(t('Failed to merge model') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
|
||||
toast(t('Failed to merge model') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
} else {
|
||||
toast(t('Merge model successfully'), { type: 'success' });
|
||||
}
|
||||
|
||||
@@ -2,16 +2,27 @@ import { CompletionPreset } from '../types/completion';
|
||||
import { ModelConfig } from '../types/configs';
|
||||
|
||||
export const defaultCompositionPrompt = '<pad>';
|
||||
export const defaultCompositionABCPrompt='S:3\n' +
|
||||
'B:9\n' +
|
||||
'E:4\n' +
|
||||
'B:9\n' +
|
||||
'E:4\n' +
|
||||
'E:4\n' +
|
||||
'B:9\n' +
|
||||
'L:1/8\n' +
|
||||
'M:3/4\n' +
|
||||
'K:D\n' +
|
||||
' Bc |"G" d2 cB"A" A2 FE |"Bm" F2 B4 F^G |'
|
||||
|
||||
export const defaultPresets: CompletionPreset[] = [{
|
||||
name: 'Writer',
|
||||
prompt: 'The following is an epic science fiction masterpiece that is immortalized, with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\n',
|
||||
params: {
|
||||
maxResponseToken: 500,
|
||||
temperature: 1.2,
|
||||
topP: 0.5,
|
||||
presencePenalty: 0.4,
|
||||
frequencyPenalty: 0.4,
|
||||
temperature: 1,
|
||||
topP: 0.3,
|
||||
presencePenalty: 0,
|
||||
frequencyPenalty: 1,
|
||||
stop: '\\n\\nUser',
|
||||
injectStart: '',
|
||||
injectEnd: ''
|
||||
@@ -211,7 +222,7 @@ export const defaultModelConfigsMac: ModelConfig[] = [
|
||||
frequencyPenalty: 1
|
||||
},
|
||||
modelParameters: {
|
||||
modelName: 'RWKV-4-MIDI-120M-v1-20230714-ctx4096.pth',
|
||||
modelName: 'RWKV-5-MIDI-120M-v1-20230728-ctx4096.pth',
|
||||
device: 'CPU',
|
||||
precision: 'fp32',
|
||||
storedLayers: 41,
|
||||
@@ -229,7 +240,7 @@ export const defaultModelConfigsMac: ModelConfig[] = [
|
||||
frequencyPenalty: 1
|
||||
},
|
||||
modelParameters: {
|
||||
modelName: 'RWKV-4-MIDI-560M-v1-20230717-ctx4096.pth',
|
||||
modelName: 'RWKV-5-MIDI-560M-v1-20230902-ctx4096.pth',
|
||||
device: 'CPU',
|
||||
precision: 'fp32',
|
||||
storedLayers: 41,
|
||||
@@ -687,7 +698,7 @@ export const defaultModelConfigs: ModelConfig[] = [
|
||||
frequencyPenalty: 1
|
||||
},
|
||||
modelParameters: {
|
||||
modelName: 'RWKV-4-MIDI-120M-v1-20230714-ctx4096.pth',
|
||||
modelName: 'RWKV-5-MIDI-120M-v1-20230728-ctx4096.pth',
|
||||
device: 'CPU',
|
||||
precision: 'fp32',
|
||||
storedLayers: 41,
|
||||
@@ -705,7 +716,7 @@ export const defaultModelConfigs: ModelConfig[] = [
|
||||
frequencyPenalty: 1
|
||||
},
|
||||
modelParameters: {
|
||||
modelName: 'RWKV-4-MIDI-560M-v1-20230717-ctx4096.pth',
|
||||
modelName: 'RWKV-5-MIDI-560M-v1-20230902-ctx4096.pth',
|
||||
device: 'CPU',
|
||||
precision: 'fp32',
|
||||
storedLayers: 41,
|
||||
|
||||
@@ -49,7 +49,7 @@ export async function startup() {
|
||||
async function initRemoteText() {
|
||||
await fetch('https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/manifest.json', { cache: 'no-cache' })
|
||||
.then(r => r.json()).then((data) => {
|
||||
if (data.version > manifest.version) {
|
||||
if (data.version >= manifest.version) {
|
||||
if (data.introduction)
|
||||
commonStore.setIntroduction(data.introduction);
|
||||
if (data.about)
|
||||
@@ -140,7 +140,8 @@ async function initHardwareMonitor() {
|
||||
|
||||
async function initMidi() {
|
||||
EventsOn('midiError', (data: string) => {
|
||||
toast('MIDI Error: ' + data, { type: 'error' });
|
||||
if (commonStore.platform === 'windows')
|
||||
toast('MIDI Error: ' + data, { type: 'error' });
|
||||
});
|
||||
EventsOn('midiPorts', (data: MidiPort[]) => {
|
||||
commonStore.setMidiPorts(data);
|
||||
|
||||
@@ -70,7 +70,9 @@ class CommonStore {
|
||||
conversationOrder: string[] = [];
|
||||
activePreset: Preset | null = null;
|
||||
attachmentUploading: boolean = false;
|
||||
attachments: { [uuid: string]: Attachment[] } = {};
|
||||
attachments: {
|
||||
[uuid: string]: Attachment[]
|
||||
} = {};
|
||||
currentTempAttachment: Attachment | null = null;
|
||||
chatParams: ChatParams = {
|
||||
maxResponseToken: 1000,
|
||||
@@ -92,8 +94,11 @@ class CommonStore {
|
||||
topP: 0.8,
|
||||
autoPlay: true,
|
||||
useLocalSoundFont: false,
|
||||
externalPlay: false,
|
||||
midi: null,
|
||||
ns: null
|
||||
ns: null,
|
||||
generationStartTime: 0,
|
||||
playOnlyGeneratedContent: true
|
||||
};
|
||||
compositionGenerating: boolean = false;
|
||||
compositionSubmittedPrompt: string = defaultCompositionPrompt;
|
||||
@@ -116,6 +121,7 @@ class CommonStore {
|
||||
modelConfigs: ModelConfig[] = [];
|
||||
modelParamsCollapsed: boolean = true;
|
||||
// models
|
||||
activeModelListTags: string[] = [];
|
||||
modelSourceManifestList: string = 'https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/manifest.json;';
|
||||
modelSourceList: ModelSourceItem[] = [];
|
||||
// downloads
|
||||
@@ -171,7 +177,8 @@ class CommonStore {
|
||||
apiUrl: '',
|
||||
apiKey: '',
|
||||
apiChatModelName: 'rwkv',
|
||||
apiCompletionModelName: 'rwkv'
|
||||
apiCompletionModelName: 'rwkv',
|
||||
coreApiUrl: ''
|
||||
};
|
||||
// about
|
||||
about: AboutContent = manifest.about;
|
||||
@@ -327,7 +334,7 @@ class CommonStore {
|
||||
savePresets();
|
||||
}
|
||||
|
||||
setActivePreset(value: Preset) {
|
||||
setActivePreset(value: Preset | null) {
|
||||
this.activePreset = value;
|
||||
}
|
||||
|
||||
@@ -379,7 +386,9 @@ class CommonStore {
|
||||
this.attachmentUploading = value;
|
||||
}
|
||||
|
||||
setAttachments(value: { [uuid: string]: Attachment[] }) {
|
||||
setAttachments(value: {
|
||||
[uuid: string]: Attachment[]
|
||||
}) {
|
||||
this.attachments = value;
|
||||
}
|
||||
|
||||
@@ -449,6 +458,10 @@ class CommonStore {
|
||||
setPlayingTrackId(value: string) {
|
||||
this.playingTrackId = value;
|
||||
}
|
||||
|
||||
setActiveModelListTags(value: string[]) {
|
||||
this.activeModelListTags = value;
|
||||
}
|
||||
}
|
||||
|
||||
export default new CommonStore();
|
||||
@@ -9,11 +9,15 @@ export type CompositionParams = {
|
||||
topP: number,
|
||||
autoPlay: boolean,
|
||||
useLocalSoundFont: boolean,
|
||||
externalPlay: boolean,
|
||||
midi: ArrayBuffer | null,
|
||||
ns: NoteSequence | null
|
||||
ns: NoteSequence | null,
|
||||
generationStartTime: number,
|
||||
playOnlyGeneratedContent: boolean,
|
||||
}
|
||||
export type Track = {
|
||||
id: string;
|
||||
mainInstrument: string;
|
||||
content: string;
|
||||
rawContent: MidiMessage[];
|
||||
offsetTime: number;
|
||||
|
||||
@@ -6,8 +6,8 @@ export type ApiParameters = {
|
||||
presencePenalty: number;
|
||||
frequencyPenalty: number;
|
||||
}
|
||||
export type Device = 'CPU' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom';
|
||||
export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4';
|
||||
export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'WebGPU (Python)' | 'MPS' | 'Custom';
|
||||
export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1';
|
||||
export type ModelParameters = {
|
||||
// different models can not have the same name
|
||||
modelName: string;
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
export type ModelSourceItem = {
|
||||
name: string;
|
||||
size: number;
|
||||
lastUpdated: string;
|
||||
desc?: { [lang: string]: string | undefined; };
|
||||
size: number;
|
||||
SHA256?: string;
|
||||
lastUpdated: string;
|
||||
url?: string;
|
||||
downloadUrl?: string;
|
||||
tags?: string[];
|
||||
customTokenizer?: string;
|
||||
hide?: boolean;
|
||||
|
||||
lastUpdatedMs?: number;
|
||||
isComplete?: boolean;
|
||||
isLocal?: boolean;
|
||||
localSize?: number;
|
||||
lastUpdatedMs?: number;
|
||||
hide?: boolean;
|
||||
};
|
||||
@@ -19,4 +19,5 @@ export type SettingsType = {
|
||||
apiKey: string
|
||||
apiChatModelName: string
|
||||
apiCompletionModelName: string
|
||||
coreApiUrl: string
|
||||
}
|
||||
118
frontend/src/utils/convert-model.ts
Normal file
118
frontend/src/utils/convert-model.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
import { toast } from 'react-toastify';
|
||||
import commonStore from '../stores/commonStore';
|
||||
import { t } from 'i18next';
|
||||
import {
|
||||
ConvertGGML,
|
||||
ConvertModel,
|
||||
ConvertSafetensors,
|
||||
ConvertSafetensorsWithPython,
|
||||
FileExists,
|
||||
GetPyError
|
||||
} from '../../wailsjs/go/backend_golang/App';
|
||||
import { WindowShow } from '../../wailsjs/runtime';
|
||||
import { ModelConfig, Precision } from '../types/configs';
|
||||
import { checkDependencies, getStrategy } from './index';
|
||||
import { NavigateFunction } from 'react-router';
|
||||
|
||||
export const convertModel = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => {
|
||||
if (commonStore.platform === 'darwin') {
|
||||
toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
|
||||
return;
|
||||
} else if (commonStore.platform === 'linux') {
|
||||
toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
|
||||
return;
|
||||
}
|
||||
|
||||
const ok = await checkDependencies(navigate);
|
||||
if (!ok)
|
||||
return;
|
||||
|
||||
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
|
||||
if (await FileExists(modelPath)) {
|
||||
const strategy = getStrategy(selectedConfig);
|
||||
const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-');
|
||||
toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
|
||||
ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => {
|
||||
if (!await FileExists(newModelPath + '.pth')) {
|
||||
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
} else {
|
||||
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
|
||||
}
|
||||
}).catch(e => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
|
||||
else
|
||||
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
|
||||
});
|
||||
setTimeout(WindowShow, 1000);
|
||||
} else {
|
||||
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
export const convertToSt = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => {
|
||||
const webgpuPython = selectedConfig.modelParameters.device === 'WebGPU (Python)';
|
||||
if (webgpuPython) {
|
||||
const ok = await checkDependencies(navigate);
|
||||
if (!ok)
|
||||
return;
|
||||
}
|
||||
|
||||
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
|
||||
if (await FileExists(modelPath)) {
|
||||
toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
|
||||
const newModelPath = modelPath.replace(/\.pth$/, '.st');
|
||||
const convert = webgpuPython ?
|
||||
(input: string, output: string) => ConvertSafetensorsWithPython(commonStore.settings.customPythonPath, input, output)
|
||||
: ConvertSafetensors;
|
||||
convert(modelPath, newModelPath).then(async () => {
|
||||
if (!await FileExists(newModelPath)) {
|
||||
if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
|
||||
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
} else {
|
||||
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
|
||||
}
|
||||
}).catch(e => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
|
||||
else
|
||||
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
|
||||
});
|
||||
setTimeout(WindowShow, 1000);
|
||||
} else {
|
||||
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
|
||||
}
|
||||
};
|
||||
|
||||
export const convertToGGML = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => {
|
||||
const ok = await checkDependencies(navigate);
|
||||
if (!ok)
|
||||
return;
|
||||
|
||||
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
|
||||
if (await FileExists(modelPath)) {
|
||||
toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
|
||||
const precision: Precision = selectedConfig.modelParameters.precision === 'Q5_1' ? 'Q5_1' : 'fp16';
|
||||
const newModelPath = modelPath.replace(/\.pth$/, `-${precision}.bin`);
|
||||
ConvertGGML(commonStore.settings.customPythonPath, modelPath, newModelPath, precision === 'Q5_1').then(async () => {
|
||||
if (!await FileExists(newModelPath)) {
|
||||
if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
|
||||
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
} else {
|
||||
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
|
||||
}
|
||||
}).catch(e => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
|
||||
else
|
||||
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
|
||||
});
|
||||
setTimeout(WindowShow, 1000);
|
||||
} else {
|
||||
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
|
||||
}
|
||||
};
|
||||
@@ -1,42 +0,0 @@
|
||||
import { toast } from 'react-toastify';
|
||||
import commonStore from '../stores/commonStore';
|
||||
import { t } from 'i18next';
|
||||
import { checkDependencies } from './index';
|
||||
import { ConvertSafetensors, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App';
|
||||
import { WindowShow } from '../../wailsjs/runtime';
|
||||
import { NavigateFunction } from 'react-router';
|
||||
import { ModelConfig } from '../types/configs';
|
||||
|
||||
export const convertToSt = async (navigate: NavigateFunction, selectedConfig: ModelConfig) => {
|
||||
if (commonStore.platform === 'linux') {
|
||||
toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_safetensors.py)', { type: 'info' });
|
||||
return;
|
||||
}
|
||||
|
||||
const ok = await checkDependencies(navigate);
|
||||
if (!ok)
|
||||
return;
|
||||
|
||||
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
|
||||
if (await FileExists(modelPath)) {
|
||||
toast(t('Start Converting'), { autoClose: 1000, type: 'info' });
|
||||
const newModelPath = modelPath.replace(/\.pth$/, '.st');
|
||||
ConvertSafetensors(commonStore.settings.customPythonPath, modelPath, newModelPath).then(async () => {
|
||||
if (!await FileExists(newModelPath)) {
|
||||
if (commonStore.platform === 'windows')
|
||||
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
} else {
|
||||
toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
|
||||
}
|
||||
}).catch(e => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
|
||||
else
|
||||
toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
|
||||
});
|
||||
setTimeout(WindowShow, 1000);
|
||||
} else {
|
||||
toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
|
||||
}
|
||||
};
|
||||
@@ -22,7 +22,12 @@ import { DownloadStatus } from '../types/downloads';
|
||||
import { ModelSourceItem } from '../types/models';
|
||||
import { Language, Languages, SettingsType } from '../types/settings';
|
||||
import { DataProcessParameters, LoraFinetuneParameters } from '../types/train';
|
||||
import { tracksMinimalTotalTime } from '../types/composition';
|
||||
import { InstrumentTypeNameMap, MidiMessage, tracksMinimalTotalTime } from '../types/composition';
|
||||
import logo from '../assets/images/logo.png';
|
||||
import { Preset } from '../types/presets';
|
||||
import { botName, Conversation, MessageType, userName } from '../types/chat';
|
||||
import { v4 as uuid } from 'uuid';
|
||||
import { findLastIndex } from 'lodash-es';
|
||||
|
||||
export type Cache = {
|
||||
version: string
|
||||
@@ -40,16 +45,18 @@ export type LocalConfig = {
|
||||
}
|
||||
|
||||
export async function refreshBuiltInModels(readCache: boolean = false) {
|
||||
let cache: { models: ModelSourceItem[] } = { models: [] };
|
||||
let cache: {
|
||||
models: ModelSourceItem[]
|
||||
} = { models: [] };
|
||||
if (readCache)
|
||||
await ReadJson('cache.json').then((cacheData: Cache) => {
|
||||
if (cacheData.models)
|
||||
cache.models = cacheData.models;
|
||||
else cache.models = manifest.models;
|
||||
else cache.models = manifest.models.slice();
|
||||
}).catch(() => {
|
||||
cache.models = manifest.models;
|
||||
cache.models = manifest.models.slice();
|
||||
});
|
||||
else cache.models = manifest.models;
|
||||
else cache.models = manifest.models.slice();
|
||||
|
||||
commonStore.setModelSourceList(cache.models);
|
||||
await saveCache().catch(() => {
|
||||
@@ -57,7 +64,7 @@ export async function refreshBuiltInModels(readCache: boolean = false) {
|
||||
return cache;
|
||||
}
|
||||
|
||||
const modelSuffix = ['.pth', '.st', '.safetensors'];
|
||||
const modelSuffix = ['.pth', '.st', '.safetensors', '.bin'];
|
||||
|
||||
export async function refreshLocalModels(cache: {
|
||||
models: ModelSourceItem[]
|
||||
@@ -73,7 +80,8 @@ export async function refreshLocalModels(cache: {
|
||||
size: d.size,
|
||||
lastUpdated: d.modTime,
|
||||
isComplete: true,
|
||||
isLocal: true
|
||||
isLocal: true,
|
||||
tags: ['Local']
|
||||
}] as ModelSourceItem[];
|
||||
return [];
|
||||
}));
|
||||
@@ -83,12 +91,15 @@ export async function refreshLocalModels(cache: {
|
||||
for (let i = 0; i < cache.models.length; i++) {
|
||||
if (!cache.models[i].lastUpdatedMs)
|
||||
cache.models[i].lastUpdatedMs = Date.parse(cache.models[i].lastUpdated);
|
||||
if (!cache.models[i].tags || !Array.isArray(cache.models[i].tags) || cache.models[i].tags?.length === 0)
|
||||
cache.models[i].tags = ['Other'];
|
||||
|
||||
for (let j = i + 1; j < cache.models.length; j++) {
|
||||
if (!cache.models[j].lastUpdatedMs)
|
||||
cache.models[j].lastUpdatedMs = Date.parse(cache.models[j].lastUpdated);
|
||||
|
||||
if (cache.models[i].name === cache.models[j].name) {
|
||||
const tags = Array.from(new Set([...cache.models[i].tags as string[], ...cache.models[j].tags as string[]]));
|
||||
if (cache.models[i].size <= cache.models[j].size) { // j is local file
|
||||
if (cache.models[i].lastUpdatedMs! < cache.models[j].lastUpdatedMs!) {
|
||||
cache.models[i] = Object.assign({}, cache.models[i], cache.models[j]);
|
||||
@@ -98,6 +109,7 @@ export async function refreshLocalModels(cache: {
|
||||
} // else is not complete local file
|
||||
cache.models[i].isLocal = true;
|
||||
cache.models[i].localSize = cache.models[j].size;
|
||||
cache.models[i].tags = tags;
|
||||
cache.models.splice(j, 1);
|
||||
j--;
|
||||
}
|
||||
@@ -132,7 +144,9 @@ function initLastUnfinishedModelDownloads() {
|
||||
commonStore.setLastUnfinishedModelDownloads(list);
|
||||
}
|
||||
|
||||
export async function refreshRemoteModels(cache: { models: ModelSourceItem[] }) {
|
||||
export async function refreshRemoteModels(cache: {
|
||||
models: ModelSourceItem[]
|
||||
}, filter: boolean = true, initUnfinishedModels: boolean = false) {
|
||||
const manifestUrls = commonStore.modelSourceManifestList.split(/[,,;;\n]/);
|
||||
const requests = manifestUrls.filter(url => url.endsWith('.json')).map(
|
||||
url => fetch(url, { cache: 'no-cache' }).then(r => r.json()));
|
||||
@@ -149,18 +163,16 @@ export async function refreshRemoteModels(cache: { models: ModelSourceItem[] })
|
||||
});
|
||||
cache.models = cache.models.filter((model, index, self) => {
|
||||
return modelSuffix.some((ext => model.name.endsWith(ext)))
|
||||
&& index === self.findIndex(
|
||||
m => m.name === model.name || (m.SHA256 && m.SHA256 === model.SHA256 && m.size === model.size));
|
||||
});
|
||||
commonStore.setModelSourceList(cache.models);
|
||||
await saveCache().catch(() => {
|
||||
&& index === findLastIndex(self,
|
||||
m => m.name === model.name || (!!m.SHA256 && m.SHA256 === model.SHA256 && m.size === model.size));
|
||||
});
|
||||
await refreshLocalModels(cache, filter, initUnfinishedModels);
|
||||
}
|
||||
|
||||
export const refreshModels = async (readCache: boolean = false, initUnfinishedModels: boolean = false) => {
|
||||
const cache = await refreshBuiltInModels(readCache);
|
||||
await refreshLocalModels(cache, false, initUnfinishedModels);
|
||||
await refreshRemoteModels(cache);
|
||||
await refreshRemoteModels(cache, false, initUnfinishedModels);
|
||||
};
|
||||
|
||||
export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) => {
|
||||
@@ -179,6 +191,7 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
|
||||
strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
|
||||
break;
|
||||
case 'WebGPU':
|
||||
case 'WebGPU (Python)':
|
||||
strategy += params.precision === 'nf4' ? 'fp16i4' : params.precision === 'int8' ? 'fp16i8' : 'fp16';
|
||||
break;
|
||||
case 'CUDA':
|
||||
@@ -189,6 +202,8 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
|
||||
strategy += params.precision === 'int8' ? 'fp16i8' : params.precision === 'fp32' ? 'fp32' : 'fp16';
|
||||
if (params.storedLayers < params.maxStoredLayers)
|
||||
strategy += ` *${params.storedLayers}+`;
|
||||
else
|
||||
strategy += ` -> cuda fp16 *1`;
|
||||
break;
|
||||
case 'MPS':
|
||||
if (avoidOverflow)
|
||||
@@ -290,7 +305,11 @@ export function bytesToReadable(size: number) {
|
||||
else return bytesToGb(size) + ' GB';
|
||||
}
|
||||
|
||||
export function getServerRoot(defaultLocalPort: number) {
|
||||
export function getServerRoot(defaultLocalPort: number, isCore: boolean = false) {
|
||||
const coreCustomApiUrl = commonStore.settings.coreApiUrl.trim().replace(/\/$/, '');
|
||||
if (isCore && coreCustomApiUrl)
|
||||
return coreCustomApiUrl;
|
||||
|
||||
const defaultRoot = `http://127.0.0.1:${defaultLocalPort}`;
|
||||
if (commonStore.status.status !== ModelStatus.Offline)
|
||||
return defaultRoot;
|
||||
@@ -305,6 +324,8 @@ export function getServerRoot(defaultLocalPort: number) {
|
||||
export function absPathAsset(path: string) {
|
||||
if (commonStore.platform === 'web')
|
||||
return path;
|
||||
if (path === logo)
|
||||
return path;
|
||||
if ((path.length > 0 && path[0] === '/') ||
|
||||
(path.length > 1 && path[1] === ':')) {
|
||||
return '=>' + path;
|
||||
@@ -418,7 +439,10 @@ export const checkDependencies = async (navigate: NavigateFunction) => {
|
||||
toastWithButton(`${t('Downloading')} Python`, t('Check'), () => {
|
||||
navigate({ pathname: '/downloads' });
|
||||
}, { autoClose: 3000 });
|
||||
AddToDownloadList('python-3.10.11-embed-amd64.zip', 'https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip');
|
||||
AddToDownloadList('python-3.10.11-embed-amd64.zip',
|
||||
!commonStore.settings.cnMirror
|
||||
? 'https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip'
|
||||
: 'https://mirrors.huaweicloud.com/python/3.10.11/python-3.10.11-embed-amd64.zip');
|
||||
});
|
||||
} else if (depErrorMsg.includes('DepCheck Error')) {
|
||||
if (depErrorMsg.includes('vc_redist') || depErrorMsg.includes('DLL load failed while importing')) {
|
||||
@@ -482,6 +506,12 @@ export function getHfDownloadUrl(url: string) {
|
||||
}
|
||||
|
||||
export function refreshTracksTotalTime() {
|
||||
if (commonStore.tracks.length === 0) {
|
||||
commonStore.setTrackTotalTime(tracksMinimalTotalTime);
|
||||
commonStore.setTrackCurrentTime(0);
|
||||
commonStore.setTrackPlayStartTime(0);
|
||||
return;
|
||||
}
|
||||
const endTimes = commonStore.tracks.map(t => t.offsetTime + t.contentTime);
|
||||
const totalTime = Math.max(...endTimes) + tracksMinimalTotalTime;
|
||||
if (commonStore.trackPlayStartTime > totalTime)
|
||||
@@ -489,20 +519,39 @@ export function refreshTracksTotalTime() {
|
||||
commonStore.setTrackTotalTime(totalTime);
|
||||
}
|
||||
|
||||
export function getMidiRawContentTime(rawContent: MidiMessage[]) {
|
||||
return rawContent.reduce((sum, current) =>
|
||||
sum + (current.messageType === 'ElapsedTime' ? current.value : 0)
|
||||
, 0);
|
||||
}
|
||||
|
||||
export function getMidiRawContentMainInstrument(rawContent: MidiMessage[]) {
|
||||
const sortedInstrumentFrequency = Object.entries(rawContent
|
||||
.filter(c => c.messageType === 'NoteOn')
|
||||
.map(c => c.instrument)
|
||||
.reduce((frequencyCount, current) => (frequencyCount[current] = (frequencyCount[current] || 0) + 1, frequencyCount)
|
||||
, {} as {
|
||||
[key: string]: number
|
||||
}))
|
||||
.sort((a, b) => b[1] - a[1]);
|
||||
let mainInstrument: string = '';
|
||||
if (sortedInstrumentFrequency.length > 0)
|
||||
mainInstrument = InstrumentTypeNameMap[Number(sortedInstrumentFrequency[0][0])];
|
||||
return mainInstrument;
|
||||
}
|
||||
|
||||
export function flushMidiRecordingContent() {
|
||||
const recordingTrackIndex = commonStore.tracks.findIndex(t => t.id === commonStore.recordingTrackId);
|
||||
if (recordingTrackIndex >= 0) {
|
||||
const recordingTrack = commonStore.tracks[recordingTrackIndex];
|
||||
const tracks = commonStore.tracks.slice();
|
||||
const contentTime = commonStore.recordingRawContent
|
||||
.reduce((sum, current) =>
|
||||
sum + (current.messageType === 'ElapsedTime' ? current.value : 0)
|
||||
, 0);
|
||||
|
||||
tracks[recordingTrackIndex] = {
|
||||
...recordingTrack,
|
||||
content: commonStore.recordingContent,
|
||||
rawContent: commonStore.recordingRawContent,
|
||||
contentTime: contentTime
|
||||
contentTime: getMidiRawContentTime(commonStore.recordingRawContent),
|
||||
mainInstrument: getMidiRawContentMainInstrument(commonStore.recordingRawContent)
|
||||
};
|
||||
commonStore.setTracks(tracks);
|
||||
refreshTracksTotalTime();
|
||||
@@ -518,7 +567,7 @@ export async function getSoundFont() {
|
||||
else
|
||||
soundUrl = !commonStore.settings.giteeUpdatesSource ?
|
||||
`https://raw.githubusercontent.com/josStorer/sgm_plus/master` :
|
||||
`https://gitee.com/josc146/sgm_plus/raw/master`;
|
||||
`https://cdn.jsdelivr.net/gh/josstorer/sgm_plus`;
|
||||
const fallbackUrl = 'https://cdn.jsdelivr.net/gh/josstorer/sgm_plus';
|
||||
await fetch(soundUrl + '/soundfont.json').then(r => {
|
||||
if (!r.ok)
|
||||
@@ -527,6 +576,30 @@ export async function getSoundFont() {
|
||||
return soundUrl;
|
||||
}
|
||||
|
||||
export const setActivePreset = (preset: Preset | null) => {
|
||||
commonStore.setActivePreset(preset);
|
||||
//TODO if (preset.displayPresetMessages) {
|
||||
const conversation: Conversation = {};
|
||||
const conversationOrder: string[] = [];
|
||||
if (preset)
|
||||
for (const message of preset.messages) {
|
||||
const newUuid = uuid();
|
||||
conversationOrder.push(newUuid);
|
||||
conversation[newUuid] = {
|
||||
sender: message.role === 'user' ? userName : botName,
|
||||
type: MessageType.Normal,
|
||||
color: message.role === 'user' ? 'brand' : 'colorful',
|
||||
time: new Date().toISOString(),
|
||||
content: message.content,
|
||||
side: message.role === 'user' ? 'right' : 'left',
|
||||
done: true
|
||||
};
|
||||
}
|
||||
commonStore.setConversation(conversation);
|
||||
commonStore.setConversationOrder(conversationOrder);
|
||||
//}
|
||||
};
|
||||
|
||||
export function getSupportedCustomCudaFile(isBeta: boolean) {
|
||||
if ([' 10', ' 16', ' 20', ' 30', 'MX', 'Tesla P', 'Quadro P', 'NVIDIA P', 'TITAN X', 'TITAN RTX', 'RTX A',
|
||||
'Quadro RTX 4000', 'Quadro RTX 5000', 'Tesla T4', 'NVIDIA A10', 'NVIDIA A40'].some(v => commonStore.status.device_name.includes(v)))
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import { getDocument, GlobalWorkerOptions, PDFDocumentProxy } from 'pdfjs-dist';
|
||||
import { TextItem } from 'pdfjs-dist/types/src/display/api';
|
||||
|
||||
export function webOpenOpenFileDialog({ filterPattern, fnStartLoading }: { filterPattern: string, fnStartLoading: Function | null }): Promise<{ blob: Blob, content?: string }> {
|
||||
export function webOpenOpenFileDialog(filterPattern: string, fnStartLoading: Function | undefined): Promise<{
|
||||
blob: Blob,
|
||||
content?: string
|
||||
}> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const input = document.createElement('input');
|
||||
input.type = 'file';
|
||||
input.accept = filterPattern
|
||||
.replaceAll('*.txt', 'text/plain')
|
||||
.replace('*.midi', 'audio/midi')
|
||||
.replace('*.mid', 'audio/midi')
|
||||
.replaceAll('*.', 'application/')
|
||||
.replaceAll(';', ',');
|
||||
|
||||
@@ -15,7 +20,7 @@ export function webOpenOpenFileDialog({ filterPattern, fnStartLoading }: { filte
|
||||
const file: Blob = e.target?.files[0];
|
||||
if (fnStartLoading && typeof fnStartLoading === 'function')
|
||||
fnStartLoading();
|
||||
if (!GlobalWorkerOptions.workerSrc)
|
||||
if (!GlobalWorkerOptions.workerSrc && file.type === 'application/pdf')
|
||||
// @ts-ignore
|
||||
GlobalWorkerOptions.workerSrc = await import('pdfjs-dist/build/pdf.worker.min.mjs');
|
||||
if (file.type === 'text/plain') {
|
||||
|
||||
@@ -48,6 +48,8 @@ if (!window.go) {
|
||||
// not implemented
|
||||
defineApp('AddToDownloadList', async () => {
|
||||
})
|
||||
defineApp('CloseMidiPort', async () => {
|
||||
})
|
||||
defineApp('ContinueDownload', async () => {
|
||||
})
|
||||
defineApp('ConvertData', async () => {
|
||||
@@ -74,8 +76,12 @@ if (!window.go) {
|
||||
})
|
||||
defineApp('OpenFileFolder', async () => {
|
||||
})
|
||||
defineApp('OpenMidiPort', async () => {
|
||||
})
|
||||
defineApp('PauseDownload', async () => {
|
||||
})
|
||||
defineApp('PlayNote', async () => {
|
||||
})
|
||||
defineApp('ReadFileInfo', async () => {
|
||||
})
|
||||
defineApp('RestartApp', async () => {
|
||||
|
||||
@@ -12,7 +12,7 @@ const vendor = [
|
||||
'mobx', 'mobx-react-lite',
|
||||
'i18next', 'react-i18next',
|
||||
'usehooks-ts', 'react-toastify',
|
||||
'classnames'
|
||||
'classnames', 'lodash-es'
|
||||
];
|
||||
|
||||
const embedded = [
|
||||
@@ -56,7 +56,10 @@ export default defineConfig({
|
||||
manualChunks: {
|
||||
vendor,
|
||||
...renderChunks(dependencies)
|
||||
}
|
||||
},
|
||||
entryFileNames: `assets/[name].js`,
|
||||
chunkFileNames: `assets/[name].js`,
|
||||
assetFileNames: `assets/[name].[ext]`
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
12
frontend/wailsjs/go/backend_golang/App.d.ts
generated
vendored
Executable file → Normal file
12
frontend/wailsjs/go/backend_golang/App.d.ts
generated
vendored
Executable file → Normal file
@@ -10,9 +10,13 @@ export function ContinueDownload(arg1:string):Promise<void>;
|
||||
|
||||
export function ConvertData(arg1:string,arg2:string,arg3:string,arg4:string):Promise<string>;
|
||||
|
||||
export function ConvertGGML(arg1:string,arg2:string,arg3:string,arg4:boolean):Promise<string>;
|
||||
|
||||
export function ConvertModel(arg1:string,arg2:string,arg3:string,arg4:string):Promise<string>;
|
||||
|
||||
export function ConvertSafetensors(arg1:string,arg2:string,arg3:string):Promise<string>;
|
||||
export function ConvertSafetensors(arg1:string,arg2:string):Promise<string>;
|
||||
|
||||
export function ConvertSafetensorsWithPython(arg1:string,arg2:string,arg3:string):Promise<string>;
|
||||
|
||||
export function CopyFile(arg1:string,arg2:string):Promise<void>;
|
||||
|
||||
@@ -56,9 +60,13 @@ export function ReadJson(arg1:string):Promise<any>;
|
||||
|
||||
export function RestartApp():Promise<void>;
|
||||
|
||||
export function SaveFile(arg1:string,arg2:Array<number>):Promise<void>;
|
||||
|
||||
export function SaveJson(arg1:string,arg2:any):Promise<void>;
|
||||
|
||||
export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean):Promise<string>;
|
||||
export function StartFile(arg1:string):Promise<void>;
|
||||
|
||||
export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean,arg6:boolean,arg7:boolean):Promise<string>;
|
||||
|
||||
export function StartWebGPUServer(arg1:number,arg2:string):Promise<string>;
|
||||
|
||||
|
||||
24
frontend/wailsjs/go/backend_golang/App.js
generated
Executable file → Normal file
24
frontend/wailsjs/go/backend_golang/App.js
generated
Executable file → Normal file
@@ -18,12 +18,20 @@ export function ConvertData(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['backend_golang']['App']['ConvertData'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ConvertGGML(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['backend_golang']['App']['ConvertGGML'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ConvertModel(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ConvertSafetensors(arg1, arg2, arg3) {
|
||||
return window['go']['backend_golang']['App']['ConvertSafetensors'](arg1, arg2, arg3);
|
||||
export function ConvertSafetensors(arg1, arg2) {
|
||||
return window['go']['backend_golang']['App']['ConvertSafetensors'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function ConvertSafetensorsWithPython(arg1, arg2, arg3) {
|
||||
return window['go']['backend_golang']['App']['ConvertSafetensorsWithPython'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function CopyFile(arg1, arg2) {
|
||||
@@ -110,12 +118,20 @@ export function RestartApp() {
|
||||
return window['go']['backend_golang']['App']['RestartApp']();
|
||||
}
|
||||
|
||||
export function SaveFile(arg1, arg2) {
|
||||
return window['go']['backend_golang']['App']['SaveFile'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function SaveJson(arg1, arg2) {
|
||||
return window['go']['backend_golang']['App']['SaveJson'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function StartServer(arg1, arg2, arg3, arg4, arg5) {
|
||||
return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5);
|
||||
export function StartFile(arg1) {
|
||||
return window['go']['backend_golang']['App']['StartFile'](arg1);
|
||||
}
|
||||
|
||||
export function StartServer(arg1, arg2, arg3, arg4, arg5, arg6, arg7) {
|
||||
return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5, arg6, arg7);
|
||||
}
|
||||
|
||||
export function StartWebGPUServer(arg1, arg2) {
|
||||
|
||||
0
frontend/wailsjs/go/models.ts
generated
Executable file → Normal file
0
frontend/wailsjs/go/models.ts
generated
Executable file → Normal file
14
go.mod
14
go.mod
@@ -9,13 +9,14 @@ require (
|
||||
github.com/minio/selfupdate v0.6.0
|
||||
github.com/nyaosorg/go-windows-su v0.2.1
|
||||
github.com/ubuntu/gowsl v0.0.0-20230615094051-94945650cc1e
|
||||
github.com/wailsapp/wails/v2 v2.6.0
|
||||
github.com/wailsapp/wails/v2 v2.7.1
|
||||
)
|
||||
|
||||
require (
|
||||
aead.dev/minisign v0.2.0 // indirect
|
||||
github.com/bep/debounce v1.2.1 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/godbus/dbus/v5 v5.1.0 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e // indirect
|
||||
github.com/labstack/echo/v4 v4.10.2 // indirect
|
||||
@@ -23,6 +24,7 @@ require (
|
||||
github.com/leaanthony/go-ansi-parser v1.6.0 // indirect
|
||||
github.com/leaanthony/gosod v1.0.3 // indirect
|
||||
github.com/leaanthony/slicer v1.6.0 // indirect
|
||||
github.com/leaanthony/u v1.1.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
|
||||
@@ -34,11 +36,11 @@ require (
|
||||
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasttemplate v1.2.2 // indirect
|
||||
github.com/wailsapp/go-webview2 v1.0.1 // indirect
|
||||
github.com/wailsapp/go-webview2 v1.0.10 // indirect
|
||||
github.com/wailsapp/mimetype v1.4.1 // indirect
|
||||
golang.org/x/crypto v0.9.0 // indirect
|
||||
golang.org/x/crypto v0.14.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect
|
||||
golang.org/x/net v0.10.0 // indirect
|
||||
golang.org/x/sys v0.9.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
)
|
||||
|
||||
28
go.sum
28
go.sum
@@ -12,6 +12,8 @@ github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e h1:Q3+PugElBCf4PFpxhErSzU3/PY5sFL5Z6rfv4AbGAck=
|
||||
@@ -29,6 +31,8 @@ github.com/leaanthony/gosod v1.0.3/go.mod h1:BJ2J+oHsQIyIQpnLPjnqFGTMnOZXDbvWtRC
|
||||
github.com/leaanthony/slicer v1.5.0/go.mod h1:FwrApmf8gOrpzEWM2J/9Lh79tyq8KTX5AzRtwV7m4AY=
|
||||
github.com/leaanthony/slicer v1.6.0 h1:1RFP5uiPJvT93TAHi+ipd3NACobkW53yUiBqZheE/Js=
|
||||
github.com/leaanthony/slicer v1.6.0/go.mod h1:o/Iz29g7LN0GqH3aMjWAe90381nyZlDNquK+mtH2Fj8=
|
||||
github.com/leaanthony/u v1.1.0 h1:2n0d2BwPVXSUq5yhe8lJPHdxevE2qK5G99PMStMZMaI=
|
||||
github.com/leaanthony/u v1.1.0/go.mod h1:9+o6hejoRljvZ3BzdYlVL0JYCwtnAsVuN9pVTQcaRfI=
|
||||
github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE=
|
||||
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
|
||||
github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
|
||||
@@ -71,24 +75,24 @@ github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyC
|
||||
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||
github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo=
|
||||
github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||
github.com/wailsapp/go-webview2 v1.0.1 h1:dEJIeEApW/MhO2tTMISZBFZPuW7kwrFA1NtgFB1z1II=
|
||||
github.com/wailsapp/go-webview2 v1.0.1/go.mod h1:Uk2BePfCRzttBBjFrBmqKGJd41P6QIHeV9kTgIeOZNo=
|
||||
github.com/wailsapp/go-webview2 v1.0.10 h1:PP5Hug6pnQEAhfRzLCoOh2jJaPdrqeRgJKZhyYyDV/w=
|
||||
github.com/wailsapp/go-webview2 v1.0.10/go.mod h1:Uk2BePfCRzttBBjFrBmqKGJd41P6QIHeV9kTgIeOZNo=
|
||||
github.com/wailsapp/mimetype v1.4.1 h1:pQN9ycO7uo4vsUUuPeHEYoUkLVkaRntMnHJxVwYhwHs=
|
||||
github.com/wailsapp/mimetype v1.4.1/go.mod h1:9aV5k31bBOv5z6u+QP8TltzvNGJPmNJD4XlAL3U+j3o=
|
||||
github.com/wailsapp/wails/v2 v2.6.0 h1:EyH0zR/EO6dDiqNy8qU5spaXDfkluiq77xrkabPYD4c=
|
||||
github.com/wailsapp/wails/v2 v2.6.0/go.mod h1:WBG9KKWuw0FKfoepBrr/vRlyTmHaMibWesK3yz6nNiM=
|
||||
github.com/wailsapp/wails/v2 v2.7.1 h1:HAzp2c5ODOzsLC6ZMDVtNOB72ozM7/SJecJPB2Ur+UU=
|
||||
github.com/wailsapp/wails/v2 v2.7.1/go.mod h1:oIJVwwso5fdOgprBYWXBBqtx6PaSvxg8/KTQHNGkadc=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
|
||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
|
||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20210505024714-0287a6fb4125/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -105,14 +109,14 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
|
||||
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
14
main.go
14
main.go
@@ -11,6 +11,7 @@ import (
|
||||
backend "rwkv-runner/backend-golang"
|
||||
|
||||
"github.com/wailsapp/wails/v2"
|
||||
wailsLogger "github.com/wailsapp/wails/v2/pkg/logger"
|
||||
"github.com/wailsapp/wails/v2/pkg/options"
|
||||
"github.com/wailsapp/wails/v2/pkg/options/assetserver"
|
||||
"github.com/wailsapp/wails/v2/pkg/options/windows"
|
||||
@@ -66,7 +67,10 @@ var midiAssets embed.FS
|
||||
var components embed.FS
|
||||
|
||||
func main() {
|
||||
dev := true
|
||||
if buildInfo, ok := debug.ReadBuildInfo(); !ok || strings.Contains(buildInfo.String(), "-ldflags") {
|
||||
dev = false
|
||||
|
||||
backend.CopyEmbed(assets)
|
||||
os.RemoveAll("./py310/Lib/site-packages/cyac-1.7.dist-info")
|
||||
backend.CopyEmbed(cyac)
|
||||
@@ -94,11 +98,18 @@ func main() {
|
||||
app.HasConfigData = false
|
||||
}
|
||||
|
||||
var logger wailsLogger.Logger
|
||||
if dev {
|
||||
logger = wailsLogger.NewDefaultLogger()
|
||||
} else {
|
||||
logger = wailsLogger.NewFileLogger("crash.log")
|
||||
}
|
||||
|
||||
// Create application with options
|
||||
err = wails.Run(&options.App{
|
||||
Title: "RWKV-Runner",
|
||||
Width: 1024,
|
||||
Height: 680,
|
||||
Height: 700,
|
||||
MinWidth: 375,
|
||||
MinHeight: 640,
|
||||
EnableDefaultContextMenu: true,
|
||||
@@ -115,6 +126,7 @@ func main() {
|
||||
Bind: []any{
|
||||
app,
|
||||
},
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
||||
596
manifest.json
596
manifest.json
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"version": "1.5.4",
|
||||
"version": "1.6.7",
|
||||
"introduction": {
|
||||
"en": "RWKV is an open-source, commercially usable large language model with high flexibility and great potential for development.\n### About This Tool\nThis tool aims to lower the barrier of entry for using large language models, making it accessible to everyone. It provides fully automated dependency and model management. You simply need to click and run, following the instructions, to deploy a local large language model. The tool itself is very compact and only requires a single executable file for one-click deployment.\nAdditionally, this tool offers an interface that is fully compatible with the OpenAI API. This means you can use any ChatGPT client as a client for RWKV, enabling capability expansion beyond just chat functionality.\n### Preset Configuration Rules at the Bottom\nThis tool comes with a series of preset configurations to reduce complexity. The naming rules for each configuration represent the following in order: device - required VRAM/memory - model size - model language.\nFor example, \"GPU-8G-3B-EN\" indicates that this configuration is for a graphics card with 8GB of VRAM, a model size of 3 billion parameters, and it uses an English language model.\nLarger model sizes have higher performance and VRAM requirements. Among configurations with the same model size, those with higher VRAM usage will have faster runtime.\nFor example, if you have 12GB of VRAM but running the \"GPU-12G-7B-EN\" configuration is slow, you can downgrade to \"GPU-8G-3B-EN\" for a significant speed improvement.\n### About RWKV\nRWKV is an RNN with Transformer-level LLM performance, which can also be directly trained like a GPT transformer (parallelizable). And it's 100% attention-free. You only need the hidden state at position t to compute the state at position t+1. You can use the \"GPT\" mode to quickly compute the hidden state for the \"RNN\" mode.<br/>So it's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, \"infinite\" ctx_len, and free sentence embedding (using the final hidden state).",
|
||||
"zh": "RWKV是一个开源且允许商用的大语言模型,灵活性很高且极具发展潜力。\n### 关于本工具\n本工具旨在降低大语言模型的使用门槛,做到人人可用,本工具提供了全自动化的依赖和模型管理,你只需要直接点击运行,跟随引导,即可完成本地大语言模型的部署,工具本身体积极小,只需要一个exe即可完成一键部署。\n此外,本工具提供了与OpenAI API完全兼容的接口,这意味着你可以把任意ChatGPT客户端用作RWKV的客户端,实现能力拓展,而不局限于聊天。\n### 底部的预设配置规则\n本工具内置了一系列预设配置,以降低使用难度,每个配置名的规则,依次代表着:设备-所需显存/内存-模型规模-模型语言。\n例如,GPU-8G-3B-CN,表示该配置用于显卡,需要8G显存,模型规模为30亿参数,使用的是中文模型。\n模型规模越大,性能要求越高,显存要求也越高,而同样模型规模的配置中,显存占用越高的,运行速度越快。\n例如当你有12G显存,但运行GPU-12G-7B-CN配置速度比较慢,可降级成GPU-8G-3B-CN,将会大幅提速。\n### 关于RWKV\nRWKV是具有Transformer级别LLM性能的RNN,也可以像GPT Transformer一样直接进行训练(可并行化)。而且它是100% attention-free的。你只需在位置t处获得隐藏状态即可计算位置t + 1处的状态。你可以使用“GPT”模式快速计算用于“RNN”模式的隐藏状态。\n因此,它将RNN和Transformer的优点结合起来 - 高性能、快速推理、节省显存、快速训练、“无限”上下文长度以及免费的语句嵌入(使用最终隐藏状态)。"
|
||||
},
|
||||
"about": {
|
||||
"en": "<div align=\"center\">\n\nProject Source Code:\nhttps://github.com/josStorer/RWKV-Runner\nAuthor: [@josStorer](https://github.com/josStorer)\nFAQs: https://github.com/josStorer/RWKV-Runner/wiki/FAQs\n\nRelated Repositories:\nRWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main\nRWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main\nChatRWKV: https://github.com/BlinkDL/ChatRWKV\nRWKV-LM: https://github.com/BlinkDL/RWKV-LM\nRWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA\nMIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer\n\n</div>",
|
||||
"zh": "<div align=\"center\">\n\n本项目源码:\nhttps://github.com/josStorer/RWKV-Runner\n作者: [@josStorer](https://github.com/josStorer)\n演示与常见问题说明视频: https://www.bilibili.com/video/BV1hM4y1v76R\n疑难解答: https://www.bilibili.com/read/cv23921171\n\n相关仓库:\nRWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main\nRWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main\nChatRWKV: https://github.com/BlinkDL/ChatRWKV\nRWKV-LM: https://github.com/BlinkDL/RWKV-LM\nRWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA\nMIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer\n\n</div>"
|
||||
"en": "<div align=\"center\">\n\nProject Source Code and Introduction:\nhttps://github.com/josStorer/RWKV-Runner\nAuthor: [@josStorer](https://github.com/josStorer)\n\nRelated Repositories:\nRWKV-5-World: https://huggingface.co/BlinkDL/rwkv-5-world/tree/main\nRWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main\nRWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main\nChatRWKV: https://github.com/BlinkDL/ChatRWKV\nRWKV-LM: https://github.com/BlinkDL/RWKV-LM\nRWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA\nMIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer\nai00_rwkv_server: https://github.com/cgisky1980/ai00_rwkv_server\nrwkv.cpp: https://github.com/saharNooby/rwkv.cpp\nweb-rwkv-py: https://github.com/cryscan/web-rwkv-py\n\n</div>",
|
||||
"zh": "<div align=\"center\">\n\n本项目源码及介绍页:\nhttps://github.com/josStorer/RWKV-Runner\n作者: [@josStorer](https://github.com/josStorer)\n演示与常见问题说明视频: https://www.bilibili.com/video/BV1hM4y1v76R\n\n相关仓库:\nRWKV-5-World: https://huggingface.co/BlinkDL/rwkv-5-world/tree/main\nRWKV-4-World: https://huggingface.co/BlinkDL/rwkv-4-world/tree/main\nRWKV-4-Raven: https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main\nChatRWKV: https://github.com/BlinkDL/ChatRWKV\nRWKV-LM: https://github.com/BlinkDL/RWKV-LM\nRWKV-LM-LoRA: https://github.com/Blealtan/RWKV-LM-LoRA\nMIDI-LLM-tokenizer: https://github.com/briansemrau/MIDI-LLM-tokenizer\nai00_rwkv_server: https://github.com/cgisky1980/ai00_rwkv_server\nrwkv.cpp: https://github.com/saharNooby/rwkv.cpp\nweb-rwkv-py: https://github.com/cryscan/web-rwkv-py\n\n</div>"
|
||||
},
|
||||
"programFiles": [
|
||||
{
|
||||
@@ -25,8 +25,13 @@
|
||||
"size": 385598386,
|
||||
"SHA256": "c844a3ee05bcb9065848cb05b10c48a3f381f5ac1953aad89e156ecdf31d7703",
|
||||
"lastUpdated": "2023-08-03T15:18:46",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-world/blob/main/RWKV-5-World-0.1B-v1-20230803-ctx4096.pth?download=true",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-0.1B-v1-20230803-ctx4096.pth?download=true"
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-world/blob/main/RWKV-5-World-0.1B-v1-20230803-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-0.1B-v1-20230803-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-5",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-5-World-0.4B-v2-20231113-ctx4096.pth",
|
||||
@@ -38,8 +43,13 @@
|
||||
"size": 923523954,
|
||||
"SHA256": "5a288c54c7f30b0e2d4af23991133fad2af2d5e59ec7ad850ffe78054a5e4f92",
|
||||
"lastUpdated": "2023-11-14T01:23:49",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-world/blob/main/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth?download=true",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth?download=true"
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-world/blob/main/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-5",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-5-World-1B5-v2-20231025-ctx4096.pth",
|
||||
@@ -52,7 +62,51 @@
|
||||
"SHA256": "5a89f56be7f82ab9dd0835af9a6838f788477471616c02f7b041e3aea0c57435",
|
||||
"lastUpdated": "2023-10-26T05:49:30",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-world/blob/main/RWKV-5-World-1B5-v2-20231025-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-1B5-v2-20231025-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-1B5-v2-20231025-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-5",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-5-1B5-one-state-slim.pth",
|
||||
"desc": {
|
||||
"en": "RWKV-5 Global Languages 1.5B v2 Ctx16k Role Play",
|
||||
"zh": "RWKV-5 全球语言 1.5B v2 16k上下文 角色扮演",
|
||||
"ja": "RWKV-5 グローバル言語 1.5B v2 16kコンテキスト ロールプレイ"
|
||||
},
|
||||
"size": 3155589871,
|
||||
"SHA256": "43e7b922d7ad49eafa17f8909c2813c91394925bc7f24caf0e19a91aa3281273",
|
||||
"lastUpdated": "2023-11-02T04:03:27",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-v5-world-v2-1.5B-one-state-slim-16k/blob/main/RWKV-5-1B5-one-state-slim.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-v5-world-v2-1.5B-one-state-slim-16k/resolve/main/RWKV-5-1B5-one-state-slim.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-5",
|
||||
"Global",
|
||||
"Role Play"
|
||||
],
|
||||
"customTokenizer": "backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt"
|
||||
},
|
||||
{
|
||||
"name": "RWKV-5-1B5-one-state-slim-novel-tuned.pth",
|
||||
"desc": {
|
||||
"en": "RWKV-5 Global Languages 1.5B v2 Ctx16k Novel",
|
||||
"zh": "RWKV-5 全球语言 1.5B v2 16k上下文 小说",
|
||||
"ja": "RWKV-5 グローバル言語 1.5B v2 16kコンテキスト 小説"
|
||||
},
|
||||
"size": 3155589871,
|
||||
"SHA256": "4f0aaecdce676e5236018ebd63e3d37c2f300fbac04001ee3a9c00d2f4244d0f",
|
||||
"lastUpdated": "2023-11-03T02:45:52",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-v5-world-v2-1.5B-one-state-slim-16k-novel-tuned/blob/main/RWKV-5-1B5-one-state-slim-novel-tuned.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-v5-world-v2-1.5B-one-state-slim-16k-novel-tuned/resolve/main/RWKV-5-1B5-one-state-slim-novel-tuned.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-5",
|
||||
"Global"
|
||||
],
|
||||
"customTokenizer": "backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt"
|
||||
},
|
||||
{
|
||||
"name": "RWKV-5-World-3B-v2-20231113-ctx4096.pth",
|
||||
@@ -64,8 +118,14 @@
|
||||
"size": 6126106674,
|
||||
"SHA256": "a4bd430343c6fd138b85bbc68bb20262d3a2f053ea57dc4b41078269af68ff9c",
|
||||
"lastUpdated": "2023-11-14T01:23:49",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-world/blob/main/RWKV-5-World-3B-v2-20231113-ctx4096.pth?download=true",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-3B-v2-20231113-ctx4096.pth?download=true"
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-world/blob/main/RWKV-5-World-3B-v2-20231113-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-3B-v2-20231113-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-5",
|
||||
"Global",
|
||||
"Recommended"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-5-World-3B-v2-20231118-ctx16k.pth",
|
||||
@@ -77,8 +137,93 @@
|
||||
"size": 6126106467,
|
||||
"SHA256": "efa5178d1c824b94ef17c6c9a456674e5581a8be832becbda9aba4dc533f88c2",
|
||||
"lastUpdated": "2023-11-19T04:21:04",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-world/blob/main/RWKV-5-World-3B-v2-20231118-ctx16k.pth?download=true",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-3B-v2-20231118-ctx16k.pth?download=true"
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-world/blob/main/RWKV-5-World-3B-v2-20231118-ctx16k.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-world/resolve/main/RWKV-5-World-3B-v2-20231118-ctx16k.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-5",
|
||||
"Global",
|
||||
"Recommended"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "rwkv-v5-7B-0.4-long-ctx-16k.pth",
|
||||
"desc": {
|
||||
"en": "RWKV-5 Global Languages 7B v2 40% Ctx300k Document Reader",
|
||||
"zh": "RWKV-5 全球语言 7B v2 40% 300k上下文 文档阅读",
|
||||
"ja": "RWKV-5 グローバル言語 7B v2 40% 300kコンテキスト ドキュメントリーダー"
|
||||
},
|
||||
"size": 15036198115,
|
||||
"SHA256": "5888471a45caab903c1bd9c35af1c639ac8d03be6ee6eb39fa9fd3194fa6d437",
|
||||
"lastUpdated": "2023-11-10T17:12:04",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-5-world-v2-7B-0.4-300k/blob/main/rwkv-v5-7B-0.4-long-ctx-16k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-5-world-v2-7B-0.4-300k/resolve/main/rwkv-v5-7B-0.4-long-ctx-16k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-5",
|
||||
"Global"
|
||||
],
|
||||
"customTokenizer": "backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt"
|
||||
},
|
||||
{
|
||||
"name": "rwkv-v5.2-7B-horror-16k.pth",
|
||||
"desc": {
|
||||
"en": "RWKV-5 Global Languages 7B v2 40% Ctx16k Horror",
|
||||
"zh": "RWKV-5 全球语言 7B v2 40% 16k上下文 恐怖",
|
||||
"ja": "RWKV-5 グローバル言語 7B v2 40% 16kコンテキスト ホラー"
|
||||
},
|
||||
"size": 15036198115,
|
||||
"SHA256": "3b36ce99bef06627dcb5d860972e2c1515327afe7db415b8c82dd5c3b926b52f",
|
||||
"lastUpdated": "2023-11-13T15:21:25",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-v5.2-7B-horror-16k/blob/main/rwkv-v5.2-7B-horror-16k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-v5.2-7B-horror-16k/resolve/main/rwkv-v5.2-7B-horror-16k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-5",
|
||||
"Global"
|
||||
],
|
||||
"customTokenizer": "backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt"
|
||||
},
|
||||
{
|
||||
"name": "rwkv_v5.2_7B_role_play_16k.pth",
|
||||
"desc": {
|
||||
"en": "RWKV-5 Global Languages 7B v2 Ctx16k Claude Like",
|
||||
"zh": "RWKV-5 全球语言 7B v2 16k上下文 Claude功能",
|
||||
"ja": "RWKV-5 グローバル言語 7B v2 16kコンテキスト Claude機能"
|
||||
},
|
||||
"size": 15036198115,
|
||||
"SHA256": "6fe8a7bf06b9f5e5b740cd87e24bff91325518ad19bf92bf5c75799b3c24b150",
|
||||
"lastUpdated": "2023-11-14T04:18:16",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-v5.2-7B-Role-play-16k/blob/main/rwkv_v5.2_7B_role_play_16k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-v5.2-7B-Role-play-16k/resolve/main/rwkv_v5.2_7B_role_play_16k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-5",
|
||||
"Global",
|
||||
"Role Play",
|
||||
"Recommended"
|
||||
],
|
||||
"customTokenizer": "backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt"
|
||||
},
|
||||
{
|
||||
"name": "RWKV-5-12B-one-state-chat-16k.pth",
|
||||
"desc": {
|
||||
"en": "RWKV-5 Global Languages 12B Ctx16k",
|
||||
"zh": "RWKV-5 全球语言 12B 16k上下文",
|
||||
"ja": "RWKV-5 グローバル言語 12B 16kコンテキスト"
|
||||
},
|
||||
"size": 23157296483,
|
||||
"SHA256": "330be74738d3936f4c9bd6caf838db11c96f52ff360d0f4fa5401d9bafc898ab",
|
||||
"lastUpdated": "2023-12-16T16:34:30",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-v5-12B-one-state-chat-16k/blob/main/RWKV-5-12B-one-state-chat-16k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-v5-12B-one-state-chat-16k/resolve/main/RWKV-5-12B-one-state-chat-16k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-5",
|
||||
"Global",
|
||||
"Recommended"
|
||||
],
|
||||
"customTokenizer": "backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt"
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-World-CHNtuned-0.1B-v1-20230617-ctx4096.pth",
|
||||
@@ -91,7 +236,12 @@
|
||||
"SHA256": "a3888f9958d378ee6d4976ae1c02edb698f4382e426086febafb4a69417b9080",
|
||||
"lastUpdated": "2023-06-17T18:35:26",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-CHNtuned-0.1B-v1-20230617-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-CHNtuned-0.1B-v1-20230617-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-CHNtuned-0.1B-v1-20230617-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"CN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-World-0.1B-v1-20230520-ctx4096.pth",
|
||||
@@ -104,7 +254,12 @@
|
||||
"SHA256": "a10ef99df2a8f8a6801edf4fc92a9c49bedd63dcb900d3e5667a2136b3d671e7",
|
||||
"lastUpdated": "2023-05-25T09:21:27",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-World-CHNtuned-0.4B-v1-20230618-ctx4096.pth",
|
||||
@@ -117,7 +272,12 @@
|
||||
"SHA256": "dbd5302cbee596bbc900f97eb10b2af3001a7f2c7e4d8643bf8683b2cdbdd324",
|
||||
"lastUpdated": "2023-06-18T10:46:50",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-CHNtuned-0.4B-v1-20230618-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-CHNtuned-0.4B-v1-20230618-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-CHNtuned-0.4B-v1-20230618-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"CN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-World-0.4B-v1-20230529-ctx4096.pth",
|
||||
@@ -130,7 +290,12 @@
|
||||
"SHA256": "4b4a2733cf5e5dc97dd62106f391d99895d16b11c5ccd10c89f28c52067a4919",
|
||||
"lastUpdated": "2023-05-29T13:25:53",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-0.4B-v1-20230529-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-0.4B-v1-20230529-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-0.4B-v1-20230529-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth",
|
||||
@@ -143,7 +308,12 @@
|
||||
"SHA256": "9f31f2ed5fe52dcf2d50208eb2efd764b9674dba2adb1baeff61997b4390a26b",
|
||||
"lastUpdated": "2023-06-20T06:35:37",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"CN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-World-1.5B-v1-OnlyForTest_57%_trained-20230529-ctx4096.pth",
|
||||
@@ -182,6 +352,11 @@
|
||||
"lastUpdated": "2023-06-07T09:33:32",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-1.5B-v1-20230607-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-1.5B-v1-20230607-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
],
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
@@ -195,7 +370,31 @@
|
||||
"SHA256": "71f0c3229f9227cbcb8ae5fee6461197129a57e26366c4d23a49058417b046c9",
|
||||
"lastUpdated": "2023-06-12T06:31:32",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-for-mobile-4-world-1.5B-20230906-ctx16k.pth",
|
||||
"desc": {
|
||||
"en": "Global Languages 1.5B v1 Ctx16k Claude Like",
|
||||
"zh": "全球语言 1.5B v1 16k上下文 Claude功能",
|
||||
"ja": "グローバル言語 1.5B v1 16kコンテキスト Claude機能"
|
||||
},
|
||||
"size": 3155280301,
|
||||
"SHA256": "20547a6deca32add57c45d2f6cff52c6b59cd3b92676ee369b964affba35619d",
|
||||
"lastUpdated": "2023-09-07T01:35:46",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-claude-for-mobile-v4-world-1.5B-16k/blob/main/RWKV-for-mobile-4-world-1.5B-20230906-ctx16k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-claude-for-mobile-v4-world-1.5B-16k/resolve/main/RWKV-for-mobile-4-world-1.5B-20230906-ctx16k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
],
|
||||
"customTokenizer": "backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt"
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-World-3B-v1-OnlyForTest_35%_trained-20230529-ctx4096.pth",
|
||||
@@ -260,7 +459,12 @@
|
||||
"SHA256": "1b227af317fa25b6939ab3c7cd321226ca48b8fe4bbbd2df3db669f1482c54ba",
|
||||
"lastUpdated": "2023-06-20T03:00:51",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-3B-v1-20230619-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-3B-v1-20230619-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-3B-v1-20230619-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-World-CHNtuned-3B-v1-20230625-ctx4096.pth",
|
||||
@@ -273,7 +477,12 @@
|
||||
"SHA256": "7d3b5a4d0e9780a3e3d9ae7c2defbe8564d240bc9a238db4ba70cfb66dc33888",
|
||||
"lastUpdated": "2023-06-25T14:53:27",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-CHNtuned-3B-v1-20230625-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-CHNtuned-3B-v1-20230625-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-CHNtuned-3B-v1-20230625-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"CN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-World-7B-v1-OnlyForTest_30%_trained-20230529-ctx4096.pth",
|
||||
@@ -364,7 +573,12 @@
|
||||
"SHA256": "db7b011247a0fe4389e1d76e3d6a904185f85d509c8a44ad18bf401094efc293",
|
||||
"lastUpdated": "2023-06-26T16:40:04",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-7B-v1-20230626-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-7B-v1-20230626-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-7B-v1-20230626-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-claude-4-World-7B-20230805-ctx65k.pth",
|
||||
@@ -377,7 +591,12 @@
|
||||
"SHA256": "8cd25f8a1ab58965993cc47b3b2f99585836eed008a2e44526c258189ea751a6",
|
||||
"lastUpdated": "2023-08-05T08:52:20",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-claude-4-World-7B-65k/blob/main/RWKV-claude-4-World-7B-20230805-ctx65k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-claude-4-World-7B-65k/resolve/main/RWKV-claude-4-World-7B-20230805-ctx65k.pth"
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-claude-4-World-7B-65k/resolve/main/RWKV-claude-4-World-7B-20230805-ctx65k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-toolformer-translation-japanese-chinese-english-7B-World-20230815-ctx128k.pth",
|
||||
@@ -390,7 +609,12 @@
|
||||
"SHA256": "648a3b21055bdab77021ce278da80fbada8dcaae0b3d41d1eca9aa194c1fd25f",
|
||||
"lastUpdated": "2023-08-15T07:18:23",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-toolformer-translation-japanese-chinese-english-7B-World-128k/blob/main/RWKV-toolformer-translation-japanese-chinese-english-7B-World-20230815-ctx128k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-toolformer-translation-japanese-chinese-english-7B-World-128k/resolve/main/RWKV-toolformer-translation-japanese-chinese-english-7B-World-20230815-ctx128k.pth"
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-toolformer-translation-japanese-chinese-english-7B-World-128k/resolve/main/RWKV-toolformer-translation-japanese-chinese-english-7B-World-20230815-ctx128k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-code-4-World-7B-20230820-ctx32k.pth",
|
||||
@@ -403,7 +627,12 @@
|
||||
"SHA256": "19666620437ae3a5fb06e16a52729d67e449fca155fab3d5861ffe9ecf247404",
|
||||
"lastUpdated": "2023-08-20T05:00:17",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-Code-7B-world-32k/blob/main/RWKV-code-4-World-7B-20230820-ctx32k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-Code-7B-world-32k/resolve/main/RWKV-code-4-World-7B-20230820-ctx32k.pth"
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-Code-7B-world-32k/resolve/main/RWKV-code-4-World-7B-20230820-ctx32k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "wizard-rwkv-4-world-ctx32k.pth",
|
||||
@@ -416,7 +645,89 @@
|
||||
"SHA256": "c5d991f315a1676d4bed93dd91f803b1376096e7a4af5bf72b339d055f53bac7",
|
||||
"lastUpdated": "2023-07-29T03:21:47",
|
||||
"url": "https://huggingface.co/xiaol/wizard-rwkv-world-7B-ctx32k/blob/main/wizard-rwkv-4-world-ctx32k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/wizard-rwkv-world-7B-ctx32k/resolve/main/wizard-rwkv-4-world-ctx32k.pth"
|
||||
"downloadUrl": "https://huggingface.co/xiaol/wizard-rwkv-world-7B-ctx32k/resolve/main/wizard-rwkv-4-world-ctx32k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-7B-world-one.pth",
|
||||
"desc": {
|
||||
"en": "Global Languages 7B v1 Ctx65k Novel",
|
||||
"zh": "全球语言 7B v1 65k上下文 小说",
|
||||
"ja": "グローバル言語 7B v1 65kコンテキスト 小説"
|
||||
},
|
||||
"size": 15035391533,
|
||||
"SHA256": "7ce95a4b460c3385c75c29b6ebe3cd7db438b1107e85d7d3e42dff85cfaa0b78",
|
||||
"lastUpdated": "2023-10-09T05:23:38",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-v4-world-7B-one-state-65k/blob/main/RWKV-7B-world-one.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-v4-world-7B-one-state-65k/resolve/main/RWKV-7B-world-one.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
],
|
||||
"customTokenizer": "backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt"
|
||||
},
|
||||
{
|
||||
"name": "rwkv-world-one-novel-cot-ultrachat-novel-instructions.pth",
|
||||
"desc": {
|
||||
"en": "Global Languages 7B v1 Ctx65k Novel Instruction",
|
||||
"zh": "全球语言 7B v1 65k上下文 小说指令",
|
||||
"ja": "グローバル言語 7B v1 65kコンテキスト 小説指示"
|
||||
},
|
||||
"size": 15035391533,
|
||||
"SHA256": "fc2d4643828bb9dfe0733c3b2eb54ba2d996ed3eb6afa051b558da2eb2c1e309",
|
||||
"lastUpdated": "2023-10-22T09:50:39",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-4-world-one-state-ultrachat-COT-65k/blob/main/rwkv-world-one-novel-cot-ultrachat-novel-instructions.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-4-world-one-state-ultrachat-COT-65k/resolve/main/rwkv-world-one-novel-cot-ultrachat-novel-instructions.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
],
|
||||
"customTokenizer": "backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt"
|
||||
},
|
||||
{
|
||||
"name": "RWKV-world-novel-one-state-ultrachat-cot-tuned-Role-play-65k.pth",
|
||||
"desc": {
|
||||
"en": "Global Languages 7B v1 Ctx65k Role Play",
|
||||
"zh": "全球语言 7B v1 65k上下文 角色扮演",
|
||||
"ja": "グローバル言語 7B v1 65kコンテキスト ロールプレイ"
|
||||
},
|
||||
"size": 15035391533,
|
||||
"SHA256": "2f55b4710dcd360e83b4df9a6358661284d9a6c6108f62c5a30b86df181ed67a",
|
||||
"lastUpdated": "2023-10-22T05:54:27",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-4-world-one-state-ultrachat-COT-65k/blob/main/RWKV-world-novel-one-state-ultrachat-cot-tuned-Role-play-65k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-4-world-one-state-ultrachat-COT-65k/resolve/main/RWKV-world-novel-one-state-ultrachat-cot-tuned-Role-play-65k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"Global",
|
||||
"Role Play"
|
||||
],
|
||||
"customTokenizer": "backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt"
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-7B-world-one-novel-tuned-65k.pth",
|
||||
"desc": {
|
||||
"en": "Global Languages 7B v1 Ctx65k Chinese Novel Instruction",
|
||||
"zh": "全球语言 7B v1 65k上下文 中文小说指令",
|
||||
"ja": "グローバル言語 7B v1 65kコンテキスト 中国語小説指示"
|
||||
},
|
||||
"size": 15035391533,
|
||||
"SHA256": "e8ff256d74ca404621dcbf87c43c37e25ea745fed30c404fbf45cc5acc7ba2b5",
|
||||
"lastUpdated": "2023-10-15T00:57:53",
|
||||
"url": "https://huggingface.co/xiaol/RWKV-4-world-one-state-novel-tuned-65k/blob/main/RWKV-4-7B-world-one-novel-tuned-65k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/RWKV-4-world-one-state-novel-tuned-65k/resolve/main/RWKV-4-7B-world-one-novel-tuned-65k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"CN"
|
||||
],
|
||||
"customTokenizer": "backend-python/rwkv_pip/rwkv_vocab_v20230424_special_token.txt"
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-World-CHNtuned-7B-v1-20230709-ctx4096.pth",
|
||||
@@ -429,7 +740,12 @@
|
||||
"SHA256": "52d33e8352a40158d21425fee4f68df1515d6324056f788d2c78a366ef578ffa",
|
||||
"lastUpdated": "2023-07-09T18:23:33",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-CHNtuned-7B-v1-20230709-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-CHNtuned-7B-v1-20230709-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-CHNtuned-7B-v1-20230709-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"CN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Readflow-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k.pth",
|
||||
@@ -442,7 +758,12 @@
|
||||
"SHA256": "1bd1de8cdbd56b67e1374588fe5d202884049c71278ffcb12f5c4efbdb422ee1",
|
||||
"lastUpdated": "2023-07-20T06:11:29",
|
||||
"url": "https://huggingface.co/xiaol/readflow-rwkv-4-world-ctx32k/blob/main/Readflow-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/readflow-rwkv-4-world-ctx32k/resolve/main/Readflow-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k.pth"
|
||||
"downloadUrl": "https://huggingface.co/xiaol/readflow-rwkv-4-world-ctx32k/resolve/main/Readflow-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"CN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "novel-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k.pth",
|
||||
@@ -455,7 +776,12 @@
|
||||
"SHA256": "0fe2415ce61af52a8c38c071b475c01b4c9f8a4f2b4aaed6181f0334f3faf7f4",
|
||||
"lastUpdated": "2023-07-28T13:30:59",
|
||||
"url": "https://huggingface.co/xiaol/ruotangwx-rwkv-7b-novel-32k/blob/main/novel-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/ruotangwx-rwkv-7b-novel-32k/resolve/main/novel-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k.pth"
|
||||
"downloadUrl": "https://huggingface.co/xiaol/ruotangwx-rwkv-7b-novel-32k/resolve/main/novel-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"CN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "chatgal-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k-1000.pth",
|
||||
@@ -468,7 +794,12 @@
|
||||
"SHA256": "aaed29cfd1bddee47c48f564aa800eb001f62fd03290d772647d5678e40d66e8",
|
||||
"lastUpdated": "2023-07-21T08:59:18",
|
||||
"url": "https://huggingface.co/xiaol/chatgal-rwkv-7b-world-32k/blob/main/chatgal-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k-1000.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/chatgal-rwkv-7b-world-32k/resolve/main/chatgal-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k-1000.pth"
|
||||
"downloadUrl": "https://huggingface.co/xiaol/chatgal-rwkv-7b-world-32k/resolve/main/chatgal-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k-1000.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"CN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "chatgal-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k-500.pth",
|
||||
@@ -481,7 +812,12 @@
|
||||
"SHA256": "b5d347d5dedb4f398ec31489ab87b75b1dee772ae7d0a34c26635cf5d95c8794",
|
||||
"lastUpdated": "2023-07-21T07:31:05",
|
||||
"url": "https://huggingface.co/xiaol/chatgal-rwkv-7b-world-32k/blob/main/chatgal-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k-500.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/chatgal-rwkv-7b-world-32k/resolve/main/chatgal-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k-500.pth"
|
||||
"downloadUrl": "https://huggingface.co/xiaol/chatgal-rwkv-7b-world-32k/resolve/main/chatgal-RWKV-4-World-CHNtuned-7B-v1-20230709-ctx32k-500.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"CN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-World-JPNtuned-7B-v1-20230718-ctx4096.pth",
|
||||
@@ -494,7 +830,12 @@
|
||||
"SHA256": "3e4c7664ce893ac1f6bb59cd76664fb5c872cb076bb82dbd534db0555b6e9fa5",
|
||||
"lastUpdated": "2023-07-18T20:01:12",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-world/blob/main/RWKV-4-World-JPNtuned-7B-v1-20230718-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-JPNtuned-7B-v1-20230718-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-JPNtuned-7B-v1-20230718-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"JP"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-novel-4-World-7B-20230810-ctx128k.pth",
|
||||
@@ -507,7 +848,12 @@
|
||||
"SHA256": "5e429c49e4cab2f29a93f87a80635422c8710d70e5b1d962c078e47d957389c8",
|
||||
"lastUpdated": "2023-08-10T06:30:32",
|
||||
"url": "https://huggingface.co/xiaol/rwkv-7B-world-novel-128k/blob/main/RWKV-novel-4-World-7B-20230810-ctx128k.pth",
|
||||
"downloadUrl": "https://huggingface.co/xiaol/rwkv-7B-world-novel-128k/resolve/main/RWKV-novel-4-World-7B-20230810-ctx128k.pth"
|
||||
"downloadUrl": "https://huggingface.co/xiaol/rwkv-7B-world-novel-128k/resolve/main/RWKV-novel-4-World-7B-20230810-ctx128k.pth",
|
||||
"tags": [
|
||||
"Finetuned",
|
||||
"RWKV-4",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-Novel-7B-v1-ChnEng-ChnPro-20230410-ctx4096.pth",
|
||||
@@ -519,7 +865,13 @@
|
||||
"SHA256": "cd40b661930dea46c0f930c51d99cef6b484fe3d641388981dee5a0c68e2b1c7",
|
||||
"lastUpdated": "2023-04-10T13:55:52",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-novel/blob/main/RWKV-4-Novel-7B-v1-ChnEng-ChnPro-20230410-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-novel/resolve/main/RWKV-4-Novel-7B-v1-ChnEng-ChnPro-20230410-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-novel/resolve/main/RWKV-4-Novel-7B-v1-ChnEng-ChnPro-20230410-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-Novel-3B-v1-ChnEng-20230412-ctx4096.pth",
|
||||
@@ -531,7 +883,13 @@
|
||||
"SHA256": "283c6e6fa10c52a93e9a01d9630f288473267ea152a49c6579b5c0427bdc9c61",
|
||||
"lastUpdated": "2023-04-12T13:18:29",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-novel/blob/main/RWKV-4-Novel-3B-v1-ChnEng-20230412-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-novel/resolve/main/RWKV-4-Novel-3B-v1-ChnEng-20230412-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-novel/resolve/main/RWKV-4-Novel-3B-v1-ChnEng-20230412-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-Novel-7B-v1-ChnEng-20230426-ctx8192.pth",
|
||||
@@ -543,7 +901,13 @@
|
||||
"SHA256": "bd08c75a296bd193dcfadb993fe06d7f9dd91ca3385231f24c592c89d25cd596",
|
||||
"lastUpdated": "2023-04-26T18:57:01",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-novel/blob/main/RWKV-4-Novel-7B-v1-ChnEng-20230426-ctx8192.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-novel/resolve/main/RWKV-4-Novel-7B-v1-ChnEng-20230426-ctx8192.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-novel/resolve/main/RWKV-4-Novel-7B-v1-ChnEng-20230426-ctx8192.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-Novel-3B-v1-Chn-20230412-ctx4096.pth",
|
||||
@@ -555,7 +919,13 @@
|
||||
"SHA256": "c41e0af2cbc66e94121377680e8224a1504fac6c9ea620c395f0a79281db26e7",
|
||||
"lastUpdated": "2023-04-12T13:18:29",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-novel/blob/main/RWKV-4-Novel-3B-v1-Chn-20230412-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-novel/resolve/main/RWKV-4-Novel-3B-v1-Chn-20230412-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-novel/resolve/main/RWKV-4-Novel-3B-v1-Chn-20230412-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-Novel-7B-v1-Chn-20230426-ctx8192.pth",
|
||||
@@ -567,7 +937,13 @@
|
||||
"SHA256": "5fced44febdf80d303250eef9c020f087abded43aaecc8caaea8a9e7f1fb771e",
|
||||
"lastUpdated": "2023-04-26T18:57:01",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-novel/blob/main/RWKV-4-Novel-7B-v1-Chn-20230426-ctx8192.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-novel/resolve/main/RWKV-4-Novel-7B-v1-Chn-20230426-ctx8192.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-novel/resolve/main/RWKV-4-Novel-7B-v1-Chn-20230426-ctx8192.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven",
|
||||
"Global"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-Raven-1B5-v11-Eng99%-Other1%-20230425-ctx4096.pth",
|
||||
@@ -580,6 +956,11 @@
|
||||
"lastUpdated": "2023-04-26T14:27:55",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-1B5-v11-Eng99%25-Other1%25-20230425-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-1B5-v11-Eng99%25-Other1%25-20230425-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven"
|
||||
],
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
@@ -592,7 +973,12 @@
|
||||
"SHA256": "6bbbffb3ee2372dfa9ef49c599e9a2bc0a01b94b6a264ba9bf5bd524fc38f723",
|
||||
"lastUpdated": "2023-05-21T07:08:56",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-1B5-v12-Eng98%25-Other2%25-20230520-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-1B5-v12-Eng98%25-Other2%25-20230520-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-1B5-v12-Eng98%25-Other2%25-20230520-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-Raven-3B-v11-Eng99%-Other1%-20230425-ctx4096.pth",
|
||||
@@ -605,6 +991,11 @@
|
||||
"lastUpdated": "2023-04-26T14:27:55",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-3B-v11-Eng99%25-Other1%25-20230425-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-3B-v11-Eng99%25-Other1%25-20230425-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven"
|
||||
],
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
@@ -617,7 +1008,12 @@
|
||||
"SHA256": "1eea1845acfe9729dfdaec66a8d1aeb91a1287d94bebbca5529c13c050540b33",
|
||||
"lastUpdated": "2023-05-21T07:13:25",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-3B-v12-Eng98%25-Other2%25-20230520-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-3B-v12-Eng98%25-Other2%25-20230520-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-3B-v12-Eng98%25-Other2%25-20230520-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-Raven-3B-v11-Eng49%-Chn49%-Jpn1%-Other1%-20230429-ctx4096.pth",
|
||||
@@ -630,6 +1026,12 @@
|
||||
"lastUpdated": "2023-04-29T11:51:51",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-3B-v11-Eng49%25-Chn49%25-Jpn1%25-Other1%25-20230429-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-3B-v11-Eng49%25-Chn49%25-Jpn1%25-Other1%25-20230429-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven",
|
||||
"CN"
|
||||
],
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
@@ -642,7 +1044,13 @@
|
||||
"SHA256": "c0abb4b745ba3523b9d8b3e1293110867ee55b1ef3dc8c122212f78396755721",
|
||||
"lastUpdated": "2023-05-28T11:51:12",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-3B-v12-Eng49%25-Chn49%25-Jpn1%25-Other1%25-20230527-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-3B-v12-Eng49%25-Chn49%25-Jpn1%25-Other1%25-20230527-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-3B-v12-Eng49%25-Chn49%25-Jpn1%25-Other1%25-20230527-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven",
|
||||
"CN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth",
|
||||
@@ -655,6 +1063,11 @@
|
||||
"lastUpdated": "2023-04-29T11:44:32",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-7B-v11x-Eng99%25-Other1%25-20230429-ctx8192.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-7B-v11x-Eng99%25-Other1%25-20230429-ctx8192.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven"
|
||||
],
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
@@ -667,7 +1080,12 @@
|
||||
"SHA256": "5a725eaeb9e09b724de6c97e6845dd0283097c7920acd05b46852ab7afa9ec32",
|
||||
"lastUpdated": "2023-05-22T10:32:17",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-7B-v12-Eng98%25-Other2%25-20230521-ctx8192.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-7B-v12-Eng98%25-Other2%25-20230521-ctx8192.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-7B-v12-Eng98%25-Other2%25-20230521-ctx8192.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-Raven-7B-v10x-Eng49%-Chn50%-Other1%-20230423-ctx4096.pth",
|
||||
@@ -680,6 +1098,12 @@
|
||||
"lastUpdated": "2023-04-24T07:48:55",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-7B-v10x-Eng49%25-Chn50%25-Other1%25-20230423-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-7B-v10x-Eng49%25-Chn50%25-Other1%25-20230423-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven",
|
||||
"CN"
|
||||
],
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
@@ -693,6 +1117,12 @@
|
||||
"lastUpdated": "2023-04-30T14:35:59",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-7B-v11-Eng49%25-Chn49%25-Jpn1%25-Other1%25-20230430-ctx8192.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-7B-v11-Eng49%25-Chn49%25-Jpn1%25-Other1%25-20230430-ctx8192.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven",
|
||||
"CN"
|
||||
],
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
@@ -705,7 +1135,13 @@
|
||||
"SHA256": "6d4a089ff36d5d9d96b669d425fc5e4e3959cab426535b52e2364df08f58b407",
|
||||
"lastUpdated": "2023-05-30T23:16:12",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-7B-v12-Eng49%25-Chn49%25-Jpn1%25-Other1%25-20230530-ctx8192.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-7B-v12-Eng49%25-Chn49%25-Jpn1%25-Other1%25-20230530-ctx8192.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-7B-v12-Eng49%25-Chn49%25-Jpn1%25-Other1%25-20230530-ctx8192.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven",
|
||||
"CN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-Raven-14B-v11x-Eng99%-Other1%-20230501-ctx8192.pth",
|
||||
@@ -718,6 +1154,11 @@
|
||||
"lastUpdated": "2023-05-02T09:43:33",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-14B-v11x-Eng99%25-Other1%25-20230501-ctx8192.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-14B-v11x-Eng99%25-Other1%25-20230501-ctx8192.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven"
|
||||
],
|
||||
"hide": true
|
||||
},
|
||||
{
|
||||
@@ -730,7 +1171,12 @@
|
||||
"SHA256": "1193b5a9ceab572e4dbb9ed1d798eab7bf4793d18904d08bd4bf183579338ae7",
|
||||
"lastUpdated": "2023-05-23T11:22:41",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-raven/blob/main/RWKV-4-Raven-14B-v12-Eng98%25-Other2%25-20230523-ctx8192.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-14B-v12-Eng98%25-Other2%25-20230523-ctx8192.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-14B-v12-Eng98%25-Other2%25-20230523-ctx8192.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Raven"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-MIDI-120M-v1-20230714-ctx4096.pth",
|
||||
@@ -743,7 +1189,12 @@
|
||||
"SHA256": "161d27dcf50d0958d230601ba1e0f8e7dd9c236105e92d2b833496412ace430c",
|
||||
"lastUpdated": "2023-07-15T08:03:36",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-music/blob/main/RWKV-4-MIDI-120M-v1-20230714-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-music/resolve/main/RWKV-4-MIDI-120M-v1-20230714-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-music/resolve/main/RWKV-4-MIDI-120M-v1-20230714-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Music"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-MIDI-560M-v1-20230717-ctx4096.pth",
|
||||
@@ -756,7 +1207,30 @@
|
||||
"SHA256": "62b21841b24af38ef176e9e9d895d9fff730cea8aa0623f53a1784d74ce828d6",
|
||||
"lastUpdated": "2023-07-17T15:02:08",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-music/blob/main/RWKV-4-MIDI-560M-v1-20230717-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-music/resolve/main/RWKV-4-MIDI-560M-v1-20230717-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-music/resolve/main/RWKV-4-MIDI-560M-v1-20230717-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Music"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-4-ABC-82M-v1-20230805-ctx1024.pth",
|
||||
"desc": {
|
||||
"en": "Music ABC 82M v1",
|
||||
"zh": "作曲 ABC 82M v1",
|
||||
"ja": "作曲 ABC 82M v1"
|
||||
},
|
||||
"size": 164183345,
|
||||
"SHA256": "4c83859f387bc3953d19890338a3e50ea7f2278e1bbb9d6eae9b773c81958a01",
|
||||
"lastUpdated": "2023-08-06T05:46:55",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-4-music/blob/main/RWKV-4-ABC-82M-v1-20230805-ctx1024.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-4-music/resolve/main/RWKV-4-ABC-82M-v1-20230805-ctx1024.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-4",
|
||||
"Music"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-5-MIDI-120M-v1-20230728-ctx4096.pth",
|
||||
@@ -769,7 +1243,12 @@
|
||||
"SHA256": "c43d4a2ee7a71a331d05d6cd818dd75f7c48c716e4b98c58e4d27231614b0144",
|
||||
"lastUpdated": "2023-07-29T02:17:27",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-music/blob/main/RWKV-5-MIDI-120M-v1-20230728-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-music/resolve/main/RWKV-5-MIDI-120M-v1-20230728-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-music/resolve/main/RWKV-5-MIDI-120M-v1-20230728-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-5",
|
||||
"Music"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-5-MIDI-560M-v1-20230902-ctx4096.pth",
|
||||
@@ -782,7 +1261,30 @@
|
||||
"SHA256": "cb4f2fd8956ca8496d6b2e33bff290c2047759b6fe74884903dbf9c73a11cc77",
|
||||
"lastUpdated": "2023-09-03T04:48:41",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-music/blob/main/RWKV-5-MIDI-560M-v1-20230902-ctx4096.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-music/resolve/main/RWKV-5-MIDI-560M-v1-20230902-ctx4096.pth"
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-music/resolve/main/RWKV-5-MIDI-560M-v1-20230902-ctx4096.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-5",
|
||||
"Music"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "RWKV-5-ABC-82M-v1-20230901-ctx1024.pth",
|
||||
"desc": {
|
||||
"en": "RWKV-5 Music ABC 82M v1",
|
||||
"zh": "RWKV-5 作曲 ABC 82M v1",
|
||||
"ja": "RWKV-5 作曲 ABC 82M v1"
|
||||
},
|
||||
"size": 164222002,
|
||||
"SHA256": "5bf9ae32e4ef05c3851d6010709c6c00dda926d110766b9a712bc48c0a53e098",
|
||||
"lastUpdated": "2023-09-02T06:55:12",
|
||||
"url": "https://huggingface.co/BlinkDL/rwkv-5-music/blob/main/RWKV-5-ABC-82M-v1-20230901-ctx1024.pth",
|
||||
"downloadUrl": "https://huggingface.co/BlinkDL/rwkv-5-music/resolve/main/RWKV-5-ABC-82M-v1-20230901-ctx1024.pth",
|
||||
"tags": [
|
||||
"Main",
|
||||
"RWKV-5",
|
||||
"Music"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
- ^backend-python/get-pip\.py
|
||||
- ^backend-python/convert_model\.py
|
||||
- ^backend-python/convert_safetensors\.py
|
||||
- ^backend-python/convert_pytorch_to_ggml\.py linguist-vendored
|
||||
- ^backend-python/utils/midi\.py
|
||||
- ^build/
|
||||
- ^finetune/lora/
|
||||
|
||||
Reference in New Issue
Block a user