From cf16e54463c96521ed9f93343da922bf36cf52df Mon Sep 17 00:00:00 2001
From: josc146
Date: Wed, 31 May 2023 14:55:13 +0800
Subject: [PATCH] fix_tokens

---
 backend-python/utils/rwkv.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py
index d3528b5..6645d83 100644
--- a/backend-python/utils/rwkv.py
+++ b/backend-python/utils/rwkv.py
@@ -10,6 +10,7 @@ from routes import state_cache
 
 END_OF_TEXT = 0
 END_OF_LINE = 187
+END_OF_LINE_DOUBLE = 535
 
 os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
 
@@ -77,6 +78,12 @@ The following is a coherent verbose detailed conversation between a girl named {
         except HTTPException:
             pass
 
+    # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
+    def fix_tokens(self, tokens):
+        if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
+            tokens = tokens[:-1] + [END_OF_LINE, END_OF_LINE]
+        return tokens
+
     def run_rnn(self, _tokens: List[str], newline_adj: int = 0):
         tokens = [int(x) for x in _tokens]
         self.model_tokens += tokens
@@ -112,7 +119,7 @@ The following is a coherent verbose detailed conversation between a girl named {
             logits = copy.deepcopy(cache["logits"])
 
         if delta_prompt != "":
-            logits = self.run_rnn(self.pipeline.encode(delta_prompt))
+            logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(delta_prompt)))
         try:
             state_cache.add_state(
                 state_cache.AddStateBody(