add support for MIDI RWKV

This commit is contained in:
josc146
2023-07-25 16:09:31 +08:00
parent 211ae342af
commit 05b9b42b56
9 changed files with 20373 additions and 100 deletions

View File

@@ -72,7 +72,7 @@ requests_num = 0
async def eval_rwkv(
model: RWKV,
model: AbstractRWKV,
request: Request,
body: ModelConfigBody,
prompt: str,
@@ -209,7 +209,7 @@ async def eval_rwkv(
@router.post("/v1/chat/completions")
@router.post("/chat/completions")
async def chat_completions(body: ChatCompletionBody, request: Request):
model: RWKV = global_var.get(global_var.Model)
model: TextRWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
@@ -302,7 +302,7 @@ The following is a coherent verbose detailed conversation between a girl named {
@router.post("/v1/completions")
@router.post("/completions")
async def completions(body: CompletionBody, request: Request):
model: RWKV = global_var.get(global_var.Model)
model: AbstractRWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
@@ -351,7 +351,7 @@ def embedding_base64(embedding: List[float]) -> str:
@router.post("/v1/engines/text-embedding-ada-002/embeddings")
@router.post("/engines/text-embedding-ada-002/embeddings")
async def embeddings(body: EmbeddingsBody, request: Request):
model: RWKV = global_var.get(global_var.Model)
model: AbstractRWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

View File

@@ -13,13 +13,16 @@ router = APIRouter()
def get_tokens_path(model_path: str):
model_path = model_path.lower()
default_tokens_path = (
f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/20B_tokenizer.json"
)
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
default_tokens_path = tokenizer_dir + "20B_tokenizer.json"
if "raven" in model_path:
return default_tokens_path
elif "world" in model_path:
return "rwkv_vocab_v20230424"
elif "midi" in model_path:
return tokenizer_dir + "tokenizer-midi.json"
else:
return default_tokens_path
@@ -66,7 +69,13 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
try:
global_var.set(
global_var.Model,
RWKV(
TextRWKV(
model=body.model,
strategy=body.strategy,
tokens_path=get_tokens_path(body.model),
)
if "midi" not in body.model.lower()
else MusicRWKV(
model=body.model,
strategy=body.strategy,
tokens_path=get_tokens_path(body.model),