custom tokenizer (#77)

This commit is contained in:
josc146
2023-09-16 00:34:11 +08:00
parent 971124d0d7
commit a25965530c
6 changed files with 77 additions and 6 deletions

View File

@@ -29,6 +29,7 @@ def get_tokens_path(model_path: str):
class SwitchModelBody(BaseModel):
model: str
strategy: str
tokenizer: Union[str, None] = None
customCuda: bool = False
class Config:
@@ -36,6 +37,7 @@ class SwitchModelBody(BaseModel):
"example": {
"model": "models/RWKV-4-World-3B-v1-20230619-ctx4096.pth",
"strategy": "cuda fp16",
"tokenizer": None,
"customCuda": False,
}
}
@@ -65,19 +67,24 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
tokenizer = (
get_tokens_path(body.model)
if body.tokenizer is None or body.tokenizer == ""
else body.tokenizer
)
try:
global_var.set(
global_var.Model,
TextRWKV(
model=body.model,
strategy=body.strategy,
tokens_path=get_tokens_path(body.model),
tokens_path=tokenizer,
)
if "midi" not in body.model.lower()
else MusicRWKV(
model=body.model,
strategy=body.strategy,
tokens_path=get_tokens_path(body.model),
tokens_path=tokenizer,
),
)
except Exception as e: