From 07d4ba0d6b71d0cd834ab959cc88a284900a1916 Mon Sep 17 00:00:00 2001 From: josc146 Date: Mon, 4 Mar 2024 21:20:53 +0800 Subject: [PATCH] fix a generation exception caused by potentially dangerous regex being passed into the stop array --- backend-python/utils/rwkv.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index 8a0cd7b..0429037 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -315,22 +315,25 @@ class AbstractRWKV(ABC): yield response, "", prompt_token_len, completion_token_len break elif type(stop) == list: - stop_exist_regex = "|".join(stop) - matched = re.search(stop_exist_regex, response) - if matched: - try: - state_cache.add_state( - state_cache.AddStateBody( - prompt=prompt + response, - tokens=self.model_tokens, - state=self.model_state, - logits=logits, + exit_flag = False + for s in stop: + if s in response: + try: + state_cache.add_state( + state_cache.AddStateBody( + prompt=prompt + response, + tokens=self.model_tokens, + state=self.model_state, + logits=logits, + ) ) - ) - except HTTPException: - pass - response = response.split(matched.group())[0] - yield response, "", prompt_token_len, completion_token_len + except HTTPException: + pass + exit_flag = True + response = response.split(s)[0] + yield response, "", prompt_token_len, completion_token_len + break + if exit_flag: break out_last = begin + i + 1 if i == self.max_tokens_per_generation - 1: