add support for MIDI RWKV

This commit is contained in:
josc146
2023-07-25 16:09:31 +08:00
parent 211ae342af
commit 05b9b42b56
9 changed files with 20373 additions and 100 deletions

View File

@@ -13,13 +13,16 @@ router = APIRouter()
def get_tokens_path(model_path: str):
model_path = model_path.lower()
default_tokens_path = (
f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/20B_tokenizer.json"
)
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
default_tokens_path = tokenizer_dir + "20B_tokenizer.json"
if "raven" in model_path:
return default_tokens_path
elif "world" in model_path:
return "rwkv_vocab_v20230424"
elif "midi" in model_path:
return tokenizer_dir + "tokenizer-midi.json"
else:
return default_tokens_path
@@ -66,7 +69,13 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
try:
global_var.set(
global_var.Model,
RWKV(
TextRWKV(
model=body.model,
strategy=body.strategy,
tokens_path=get_tokens_path(body.model),
)
if "midi" not in body.model.lower()
else MusicRWKV(
model=body.model,
strategy=body.strategy,
tokens_path=get_tokens_path(body.model),