fix a generation exception caused by potentially dangerous regex being passed into the stop array

This commit is contained in:
josc146 2024-03-04 21:20:53 +08:00
parent ac139d5bda
commit 07d4ba0d6b

View File

@ -315,22 +315,25 @@ class AbstractRWKV(ABC):
yield response, "", prompt_token_len, completion_token_len yield response, "", prompt_token_len, completion_token_len
break break
elif type(stop) == list: elif type(stop) == list:
stop_exist_regex = "|".join(stop) exit_flag = False
matched = re.search(stop_exist_regex, response) for s in stop:
if matched: if s in response:
try: try:
state_cache.add_state( state_cache.add_state(
state_cache.AddStateBody( state_cache.AddStateBody(
prompt=prompt + response, prompt=prompt + response,
tokens=self.model_tokens, tokens=self.model_tokens,
state=self.model_state, state=self.model_state,
logits=logits, logits=logits,
)
) )
) except HTTPException:
except HTTPException: pass
pass exit_flag = True
response = response.split(matched.group())[0] response = response.split(s)[0]
yield response, "", prompt_token_len, completion_token_len yield response, "", prompt_token_len, completion_token_len
break
if exit_flag:
break break
out_last = begin + i + 1 out_last = begin + i + 1
if i == self.max_tokens_per_generation - 1: if i == self.max_tokens_per_generation - 1: