fix a generation exception caused by potentially dangerous regex being passed into the stop array

This commit is contained in:
josc146 2024-03-04 21:20:53 +08:00
parent ac139d5bda
commit 07d4ba0d6b

View File

@ -315,9 +315,9 @@ class AbstractRWKV(ABC):
yield response, "", prompt_token_len, completion_token_len yield response, "", prompt_token_len, completion_token_len
break break
elif type(stop) == list: elif type(stop) == list:
stop_exist_regex = "|".join(stop) exit_flag = False
matched = re.search(stop_exist_regex, response) for s in stop:
if matched: if s in response:
try: try:
state_cache.add_state( state_cache.add_state(
state_cache.AddStateBody( state_cache.AddStateBody(
@ -329,9 +329,12 @@ class AbstractRWKV(ABC):
) )
except HTTPException: except HTTPException:
pass pass
response = response.split(matched.group())[0] exit_flag = True
response = response.split(s)[0]
yield response, "", prompt_token_len, completion_token_len yield response, "", prompt_token_len, completion_token_len
break break
if exit_flag:
break
out_last = begin + i + 1 out_last = begin + i + 1
if i == self.max_tokens_per_generation - 1: if i == self.max_tokens_per_generation - 1:
try: try: