diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index d3528b5..6645d83 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -10,6 +10,7 @@ from routes import state_cache END_OF_TEXT = 0 END_OF_LINE = 187 +END_OF_LINE_DOUBLE = 535 os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}" @@ -77,6 +78,12 @@ The following is a coherent verbose detailed conversation between a girl named { except HTTPException: pass + # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end + def fix_tokens(tokens): + if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE: + tokens = tokens[:-1] + [END_OF_LINE, END_OF_LINE] + return tokens + def run_rnn(self, _tokens: List[str], newline_adj: int = 0): tokens = [int(x) for x in _tokens] self.model_tokens += tokens @@ -112,7 +119,7 @@ The following is a coherent verbose detailed conversation between a girl named { logits = copy.deepcopy(cache["logits"]) if delta_prompt != "": - logits = self.run_rnn(self.pipeline.encode(delta_prompt)) + logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(delta_prompt))) try: state_cache.add_state( state_cache.AddStateBody(