diff --git a/backend-python/rwkv_pip/utils.py b/backend-python/rwkv_pip/utils.py index c3ca814..0429fa7 100644 --- a/backend-python/rwkv_pip/utils.py +++ b/backend-python/rwkv_pip/utils.py @@ -33,7 +33,7 @@ class PIPELINE_ARGS: class PIPELINE: - def __init__(self, model, WORD_NAME): + def __init__(self, model, WORD_NAME: str): self.model = model if WORD_NAME == "cl100k_base": import tiktoken @@ -47,9 +47,15 @@ class PIPELINE: os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt" ) else: - from tokenizers import Tokenizer + if WORD_NAME.endswith(".txt"): + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from rwkv_tokenizer import TRIE_TOKENIZER - self.tokenizer = Tokenizer.from_file(WORD_NAME) + self.tokenizer = TRIE_TOKENIZER(WORD_NAME) + else: + from tokenizers import Tokenizer + + self.tokenizer = Tokenizer.from_file(WORD_NAME) def refine_context(self, context): context = context.strip().split("\n")