add support for MIDI RWKV
This commit is contained in:
@@ -13,13 +13,16 @@ router = APIRouter()
|
||||
|
||||
def get_tokens_path(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
default_tokens_path = (
|
||||
f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/20B_tokenizer.json"
|
||||
)
|
||||
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
|
||||
|
||||
default_tokens_path = tokenizer_dir + "20B_tokenizer.json"
|
||||
|
||||
if "raven" in model_path:
|
||||
return default_tokens_path
|
||||
elif "world" in model_path:
|
||||
return "rwkv_vocab_v20230424"
|
||||
elif "midi" in model_path:
|
||||
return tokenizer_dir + "tokenizer-midi.json"
|
||||
else:
|
||||
return default_tokens_path
|
||||
|
||||
@@ -66,7 +69,13 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
||||
try:
|
||||
global_var.set(
|
||||
global_var.Model,
|
||||
RWKV(
|
||||
TextRWKV(
|
||||
model=body.model,
|
||||
strategy=body.strategy,
|
||||
tokens_path=get_tokens_path(body.model),
|
||||
)
|
||||
if "midi" not in body.model.lower()
|
||||
else MusicRWKV(
|
||||
model=body.model,
|
||||
strategy=body.strategy,
|
||||
tokens_path=get_tokens_path(body.model),
|
||||
|
||||
Reference in New Issue
Block a user