fix a generation exception caused by potentially dangerous regex being passed into the stop array

This commit is contained in:
josc146 2024-03-04 21:20:53 +08:00
parent ac139d5bda
commit 07d4ba0d6b

View File

@ -315,9 +315,9 @@ class AbstractRWKV(ABC):
yield response, "", prompt_token_len, completion_token_len
break
elif type(stop) == list:
stop_exist_regex = "|".join(stop)
matched = re.search(stop_exist_regex, response)
if matched:
exit_flag = False
for s in stop:
if s in response:
try:
state_cache.add_state(
state_cache.AddStateBody(
@ -329,9 +329,12 @@ class AbstractRWKV(ABC):
)
except HTTPException:
pass
response = response.split(matched.group())[0]
exit_flag = True
response = response.split(s)[0]
yield response, "", prompt_token_len, completion_token_len
break
if exit_flag:
break
out_last = begin + i + 1
if i == self.max_tokens_per_generation - 1:
try: