RWKVType now no longer relies on the file name

This commit is contained in:
josc146
2023-10-26 16:55:33 +08:00
parent 1d7f19ffaf
commit 627a20936d
2 changed files with 65 additions and 62 deletions

View File

@@ -10,22 +10,6 @@ import global_var
router = APIRouter()
def get_tokens_path(model_path: str):
model_path = model_path.lower()
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
default_tokens_path = tokenizer_dir + "20B_tokenizer.json"
if "raven" in model_path:
return default_tokens_path
elif "world" in model_path:
return "rwkv_vocab_v20230424"
elif "midi" in model_path:
return tokenizer_dir + "tokenizer-midi.json"
else:
return default_tokens_path
class SwitchModelBody(BaseModel):
model: str
strategy: str
@@ -67,25 +51,10 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
tokenizer = (
get_tokens_path(body.model)
if body.tokenizer is None or body.tokenizer == ""
else body.tokenizer
)
try:
global_var.set(
global_var.Model,
TextRWKV(
model=body.model,
strategy=body.strategy,
tokens_path=tokenizer,
)
if "midi" not in body.model.lower()
else MusicRWKV(
model=body.model,
strategy=body.strategy,
tokens_path=tokenizer,
),
RWKV(model=body.model, strategy=body.strategy, tokenizer=body.tokenizer),
)
except Exception as e:
print(e)