RWKVType now no longer relies on the file name
This commit is contained in:
@@ -10,22 +10,6 @@ import global_var
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def get_tokens_path(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
|
||||
|
||||
default_tokens_path = tokenizer_dir + "20B_tokenizer.json"
|
||||
|
||||
if "raven" in model_path:
|
||||
return default_tokens_path
|
||||
elif "world" in model_path:
|
||||
return "rwkv_vocab_v20230424"
|
||||
elif "midi" in model_path:
|
||||
return tokenizer_dir + "tokenizer-midi.json"
|
||||
else:
|
||||
return default_tokens_path
|
||||
|
||||
|
||||
class SwitchModelBody(BaseModel):
|
||||
model: str
|
||||
strategy: str
|
||||
@@ -67,25 +51,10 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
||||
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
|
||||
|
||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
||||
tokenizer = (
|
||||
get_tokens_path(body.model)
|
||||
if body.tokenizer is None or body.tokenizer == ""
|
||||
else body.tokenizer
|
||||
)
|
||||
try:
|
||||
global_var.set(
|
||||
global_var.Model,
|
||||
TextRWKV(
|
||||
model=body.model,
|
||||
strategy=body.strategy,
|
||||
tokens_path=tokenizer,
|
||||
)
|
||||
if "midi" not in body.model.lower()
|
||||
else MusicRWKV(
|
||||
model=body.model,
|
||||
strategy=body.strategy,
|
||||
tokens_path=tokenizer,
|
||||
),
|
||||
RWKV(model=body.model, strategy=body.strategy, tokenizer=body.tokenizer),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
Reference in New Issue
Block a user