This commit is contained in:
josc146 2023-07-26 22:24:26 +08:00
parent 1df345b5eb
commit d0fd480bd6
9 changed files with 78 additions and 24 deletions

1
.gitattributes vendored
View File

@ -2,6 +2,7 @@ backend-python/rwkv_pip/** linguist-vendored
backend-python/wkv_cuda_utils/** linguist-vendored backend-python/wkv_cuda_utils/** linguist-vendored
backend-python/get-pip.py linguist-vendored backend-python/get-pip.py linguist-vendored
backend-python/convert_model.py linguist-vendored backend-python/convert_model.py linguist-vendored
backend-python/utils/midi.py linguist-vendored
build/** linguist-vendored build/** linguist-vendored
finetune/lora/** linguist-vendored finetune/lora/** linguist-vendored
finetune/json2binidx_tool/** linguist-vendored finetune/json2binidx_tool/** linguist-vendored

1
.gitignore vendored
View File

@ -23,3 +23,4 @@ __pycache__
*.log *.log
train_log.txt train_log.txt
finetune/json2binidx_tool/data finetune/json2binidx_tool/data
/wsl.state

View File

@ -42,12 +42,12 @@ def init():
ngrok_connect() ngrok_connect()
@app.get("/") @app.get("/", tags=["Root"])
def read_root(): def read_root():
return {"Hello": "World!"} return {"Hello": "World!"}
@app.post("/exit") @app.post("/exit", tags=["Root"])
def exit(): def exit():
parent_pid = os.getpid() parent_pid = os.getpid()
parent = psutil.Process(parent_pid) parent = psutil.Process(parent_pid)

View File

@ -206,8 +206,8 @@ async def eval_rwkv(
} }
@router.post("/v1/chat/completions") @router.post("/v1/chat/completions", tags=["Completions"])
@router.post("/chat/completions") @router.post("/chat/completions", tags=["Completions"])
async def chat_completions(body: ChatCompletionBody, request: Request): async def chat_completions(body: ChatCompletionBody, request: Request):
model: TextRWKV = global_var.get(global_var.Model) model: TextRWKV = global_var.get(global_var.Model)
if model is None: if model is None:
@ -299,8 +299,8 @@ The following is a coherent verbose detailed conversation between a girl named {
return None return None
@router.post("/v1/completions") @router.post("/v1/completions", tags=["Completions"])
@router.post("/completions") @router.post("/completions", tags=["Completions"])
async def completions(body: CompletionBody, request: Request): async def completions(body: CompletionBody, request: Request):
model: AbstractRWKV = global_var.get(global_var.Model) model: AbstractRWKV = global_var.get(global_var.Model)
if model is None: if model is None:
@ -346,10 +346,10 @@ def embedding_base64(embedding: List[float]) -> str:
return base64.b64encode(np.array(embedding).astype(np.float32)).decode("utf-8") return base64.b64encode(np.array(embedding).astype(np.float32)).decode("utf-8")
@router.post("/v1/embeddings") @router.post("/v1/embeddings", tags=["Embeddings"])
@router.post("/embeddings") @router.post("/embeddings", tags=["Embeddings"])
@router.post("/v1/engines/text-embedding-ada-002/embeddings") @router.post("/v1/engines/text-embedding-ada-002/embeddings", tags=["Embeddings"])
@router.post("/engines/text-embedding-ada-002/embeddings") @router.post("/engines/text-embedding-ada-002/embeddings", tags=["Embeddings"])
async def embeddings(body: EmbeddingsBody, request: Request): async def embeddings(body: EmbeddingsBody, request: Request):
model: AbstractRWKV = global_var.get(global_var.Model) model: AbstractRWKV = global_var.get(global_var.Model)
if model is None: if model is None:

View File

@ -42,7 +42,7 @@ class SwitchModelBody(BaseModel):
} }
@router.post("/switch-model") @router.post("/switch-model", tags=["Configs"])
def switch_model(body: SwitchModelBody, response: Response, request: Request): def switch_model(body: SwitchModelBody, response: Response, request: Request):
if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading: if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
response.status_code = Status.HTTP_304_NOT_MODIFIED response.status_code = Status.HTTP_304_NOT_MODIFIED
@ -98,7 +98,7 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
return "success" return "success"
@router.post("/update-config") @router.post("/update-config", tags=["Configs"])
def update_config(body: ModelConfigBody): def update_config(body: ModelConfigBody):
""" """
Will not update the model config immediately, but set it when completion called to avoid modifications during generation Will not update the model config immediately, but set it when completion called to avoid modifications during generation
@ -110,7 +110,7 @@ def update_config(body: ModelConfigBody):
return "success" return "success"
@router.get("/status") @router.get("/status", tags=["Configs"])
def status(): def status():
gpus = GPUtil.getGPUs() gpus = GPUtil.getGPUs()
if len(gpus) == 0: if len(gpus) == 0:

View File

@ -1,4 +1,6 @@
import io
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, HTTPException, status
from starlette.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from utils.midi import * from utils.midi import *
from midi2audio import FluidSynth from midi2audio import FluidSynth
@ -6,6 +8,29 @@ from midi2audio import FluidSynth
router = APIRouter() router = APIRouter()
class TextToMidiBody(BaseModel):
text: str
class Config:
schema_extra = {
"example": {
"text": "p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:2e:a p:3b:a p:45:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2e:0 p:3b:0 p:45:0 b:26:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:26:a g:3e:a g:3e:a g:42:a g:42:a g:45:a g:45:a pi:3e:a pi:42:a pi:45:a t14 p:2a:0 p:3b:0 p:45:0 b:26:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:2d:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:2d:0 g:3e:0 g:3e:0 g:42:0 g:42:0 g:45:0 g:45:0 pi:3e:0 pi:42:0 pi:45:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2a:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a g:39:a g:39:a g:3d:a g:3d:a g:40:a g:40:a pi:39:a pi:3d:a pi:40:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 t2 p:26:a p:2e:a p:31:a p:39:a p:3b:a p:45:a b:21:a t14 p:26:0 p:2e:0 p:31:0 p:39:0 p:3b:0 p:45:0 b:21:0 g:39:0 g:39:0 g:3d:0 g:3d:0 g:40:0 g:40:0 pi:39:0 pi:3d:0 pi:40:0 t2 p:24:a p:2a:a p:31:a p:39:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:24:0 p:2a:0 p:31:0 p:39:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:2e:a p:3b:a p:45:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2e:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:26:a p:2a:a p:3b:a p:45:a t14 p:26:0 p:2a:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a b:1f:a g:3b:a g:3b:a g:3e:a g:3e:a g:43:a g:43:a pi:3b:a pi:3e:a pi:43:a t14 p:2a:0 p:3b:0 p:45:0 b:1f:0 t2 p:24:a p:2a:a p:3b:a p:45:a b:1f:a t14 p:24:0 p:2a:0 p:3b:0 p:45:0 b:1f:0 g:3b:0 g:3b:0 g:3e:0 g:3e:0 g:43:0 g:43:0 pi:3b:0 pi:3e:0 pi:43:0 t2 p:24:a p:2e:a p:3b:a p:45:a b:26:a g:39:a g:39:a g:3e:a g:3e:a g:42:a g:42:a pi:39:a pi:3e:a pi:42:a t14 p:24:0 p:2e:0 p:3b:0 p:45:0 t2 p:2a:a p:3b:a p:45:a t14 p:2a:0 p:3b:0",
}
}
@router.post("/text-to-midi", tags=["MIDI"])
def text_to_midi(body: TextToMidiBody):
vocab_config = "backend-python/utils/midi_vocab_config.json"
cfg = VocabConfig.from_json(vocab_config)
mid = convert_str_to_midi(cfg, body.text.strip())
mid_data = io.BytesIO()
mid.save(None, mid_data)
mid_data.seek(0)
return StreamingResponse(mid_data, media_type="audio/midi")
class TxtToMidiBody(BaseModel): class TxtToMidiBody(BaseModel):
txt_path: str txt_path: str
midi_path: str midi_path: str
@ -19,7 +44,7 @@ class TxtToMidiBody(BaseModel):
} }
@router.post("/txt-to-midi") @router.post("/txt-to-midi", tags=["MIDI"])
def txt_to_midi(body: TxtToMidiBody): def txt_to_midi(body: TxtToMidiBody):
if not body.midi_path.startswith("midi/"): if not body.midi_path.startswith("midi/"):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path") raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
@ -32,6 +57,8 @@ def txt_to_midi(body: TxtToMidiBody):
mid = convert_str_to_midi(cfg, text) mid = convert_str_to_midi(cfg, text)
mid.save(body.midi_path) mid.save(body.midi_path)
return "success"
class MidiToWavBody(BaseModel): class MidiToWavBody(BaseModel):
midi_path: str midi_path: str
@ -48,15 +75,20 @@ class MidiToWavBody(BaseModel):
} }
# install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions @router.post("/midi-to-wav", tags=["MIDI"])
@router.post("/midi-to-wav")
def midi_to_wav(body: MidiToWavBody): def midi_to_wav(body: MidiToWavBody):
"""
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
"""
if not body.wav_path.startswith("midi/"): if not body.wav_path.startswith("midi/"):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path") raise HTTPException(status.HTTP_400_BAD_REQUEST, "bad output path")
fs = FluidSynth(body.sound_font_path) fs = FluidSynth(body.sound_font_path)
fs.midi_to_audio(body.midi_path, body.wav_path) fs.midi_to_audio(body.midi_path, body.wav_path)
return "success"
class TextToWavBody(BaseModel): class TextToWavBody(BaseModel):
text: str text: str
@ -73,9 +105,12 @@ class TextToWavBody(BaseModel):
} }
# install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions @router.post("/text-to-wav", tags=["MIDI"])
@router.post("/text-to-wav")
def text_to_wav(body: TextToWavBody): def text_to_wav(body: TextToWavBody):
"""
Install fluidsynth first, see more: https://github.com/FluidSynth/fluidsynth/wiki/Download#distributions
"""
text = body.text.strip() text = body.text.strip()
if not text.startswith("<start>"): if not text.startswith("<start>"):
text = "<start> " + text text = "<start> " + text
@ -92,3 +127,5 @@ def text_to_wav(body: TextToWavBody):
midi_path=midi_path, wav_path=wav_path, sound_font_path=body.sound_font_path midi_path=midi_path, wav_path=wav_path, sound_font_path=body.sound_font_path
) )
) )
return "success"

View File

@ -32,7 +32,7 @@ def init():
print("cyac not found") print("cyac not found")
@router.post("/disable-state-cache") @router.post("/disable-state-cache", tags=["State Cache"])
def disable_state_cache(): def disable_state_cache():
global trie, dtrie global trie, dtrie
@ -43,7 +43,7 @@ def disable_state_cache():
return "success" return "success"
@router.post("/enable-state-cache") @router.post("/enable-state-cache", tags=["State Cache"])
def enable_state_cache(): def enable_state_cache():
global trie, dtrie global trie, dtrie
try: try:
@ -65,7 +65,7 @@ class AddStateBody(BaseModel):
logits: Any logits: Any
@router.post("/add-state") @router.post("/add-state", tags=["State Cache"])
def add_state(body: AddStateBody): def add_state(body: AddStateBody):
global trie, dtrie, loop_del_trie_id global trie, dtrie, loop_del_trie_id
if trie is None: if trie is None:
@ -105,7 +105,7 @@ def add_state(body: AddStateBody):
) )
@router.post("/reset-state") @router.post("/reset-state", tags=["State Cache"])
def reset_state(): def reset_state():
global trie, dtrie global trie, dtrie
if trie is None: if trie is None:
@ -141,7 +141,7 @@ def _get_a_dtrie_buff_size(dtrie_v):
return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28 # TODO return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28 # TODO
@router.post("/longest-prefix-state") @router.post("/longest-prefix-state", tags=["State Cache"])
def longest_prefix_state(body: LongestPrefixStateBody, request: Request): def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
global trie global trie
if trie is None: if trie is None:
@ -180,7 +180,7 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
} }
@router.post("/save-state") @router.post("/save-state", tags=["State Cache"])
def save_state(): def save_state():
global trie global trie
if trie is None: if trie is None:

View File

@ -1,3 +1,15 @@
# https://github.com/briansemrau/MIDI-LLM-tokenizer
# MIT License
# Copyright (c) 2023 Brian Semrau
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import json import json
import random import random
from dataclasses import dataclass from dataclasses import dataclass

View File

@ -2,5 +2,8 @@
- ^backend-python/wkv_cuda_utils/ - ^backend-python/wkv_cuda_utils/
- ^backend-python/get-pip\.py - ^backend-python/get-pip\.py
- ^backend-python/convert_model\.py - ^backend-python/convert_model\.py
- ^backend-python/utils/midi\.py
- ^build/ - ^build/
- ^finetune/lora/
- ^finetune/json2binidx_tool/
- ^frontend/wailsjs/ - ^frontend/wailsjs/