support for rwkv-4-world

This commit is contained in:
josc146
2023-05-28 12:53:14 +08:00
parent b7fb8ed898
commit 94971bb666
8 changed files with 65918 additions and 65 deletions

View File

@@ -11,6 +11,19 @@ import GPUtil
router = APIRouter()
def get_tokens_path(model_path: str):
model_path = model_path.lower()
default_tokens_path = (
f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
)
if "raven" in model_path:
return default_tokens_path
elif "world" in model_path:
return "rwkv_vocab_v20230424"
else:
return default_tokens_path
class SwitchModelBody(BaseModel):
model: str
strategy: str
@@ -36,7 +49,7 @@ def switch_model(body: SwitchModelBody, response: Response):
RWKV(
model=body.model,
strategy=body.strategy,
tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json",
tokens_path=get_tokens_path(body.model),
),
)
except Exception as e: