fix a generation exception caused by potentially dangerous regex being passed into the stop array
This commit is contained in:
		
							parent
							
								
									ac139d5bda
								
							
						
					
					
						commit
						07d4ba0d6b
					
				@ -315,22 +315,25 @@ class AbstractRWKV(ABC):
 | 
			
		||||
                            yield response, "", prompt_token_len, completion_token_len
 | 
			
		||||
                            break
 | 
			
		||||
                    elif type(stop) == list:
 | 
			
		||||
                        stop_exist_regex = "|".join(stop)
 | 
			
		||||
                        matched = re.search(stop_exist_regex, response)
 | 
			
		||||
                        if matched:
 | 
			
		||||
                            try:
 | 
			
		||||
                                state_cache.add_state(
 | 
			
		||||
                                    state_cache.AddStateBody(
 | 
			
		||||
                                        prompt=prompt + response,
 | 
			
		||||
                                        tokens=self.model_tokens,
 | 
			
		||||
                                        state=self.model_state,
 | 
			
		||||
                                        logits=logits,
 | 
			
		||||
                        exit_flag = False
 | 
			
		||||
                        for s in stop:
 | 
			
		||||
                            if s in response:
 | 
			
		||||
                                try:
 | 
			
		||||
                                    state_cache.add_state(
 | 
			
		||||
                                        state_cache.AddStateBody(
 | 
			
		||||
                                            prompt=prompt + response,
 | 
			
		||||
                                            tokens=self.model_tokens,
 | 
			
		||||
                                            state=self.model_state,
 | 
			
		||||
                                            logits=logits,
 | 
			
		||||
                                        )
 | 
			
		||||
                                    )
 | 
			
		||||
                                )
 | 
			
		||||
                            except HTTPException:
 | 
			
		||||
                                pass
 | 
			
		||||
                            response = response.split(matched.group())[0]
 | 
			
		||||
                            yield response, "", prompt_token_len, completion_token_len
 | 
			
		||||
                                except HTTPException:
 | 
			
		||||
                                    pass
 | 
			
		||||
                                exit_flag = True
 | 
			
		||||
                                response = response.split(s)[0]
 | 
			
		||||
                                yield response, "", prompt_token_len, completion_token_len
 | 
			
		||||
                                break
 | 
			
		||||
                        if exit_flag:
 | 
			
		||||
                            break
 | 
			
		||||
                out_last = begin + i + 1
 | 
			
		||||
                if i == self.max_tokens_per_generation - 1:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user