fix_tokens
This commit is contained in:
		
							parent
							
								
									7bc8da2e29
								
							
						
					
					
						commit
						cf16e54463
					
				@ -10,6 +10,7 @@ from routes import state_cache
 | 
			
		||||
 | 
			
		||||
# Token id used to mark end of text.
# NOTE(review): presumably the tokenizer's end-of-text id — confirm against the vocabulary.
END_OF_TEXT = 0
 | 
			
		||||
# Token id of a single newline ('\n'); a double newline was historically fed to the model as [187, 187].
END_OF_LINE = 187
 | 
			
		||||
# Token id the tokenizer emits for a double newline ('\n\n') at the end of its input.
END_OF_LINE_DOUBLE = 535
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Set TORCH_EXTENSIONS_DIR to the resolved directory two levels above this file.
os.environ["TORCH_EXTENSIONS_DIR"] = str(pathlib.Path(__file__).parent.parent.resolve())
 | 
			
		||||
@ -77,6 +78,12 @@ The following is a coherent verbose detailed conversation between a girl named {
 | 
			
		||||
        except HTTPException:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    @staticmethod
    def fix_tokens(tokens):
        r"""Rewrite a trailing double-newline token as two single-newline tokens.

        The model only saw '\n\n' as [187, 187] before, but the tokenizer
        outputs [535] for it at the end of the input, so a final
        END_OF_LINE_DOUBLE is replaced by two END_OF_LINE tokens.

        Declared as a static method because it uses no instance state and is
        invoked as ``self.fix_tokens(tokens)``; without the decorator that
        bound call would pass two positional arguments and raise TypeError.

        Args:
            tokens: List of token ids produced by the tokenizer.

        Returns:
            The same list when it is empty or does not end with
            END_OF_LINE_DOUBLE; otherwise a new list with the last token
            replaced by ``[END_OF_LINE, END_OF_LINE]``.
        """
        # Only the final token is rewritten; earlier occurrences are left as-is.
        if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
            tokens = tokens[:-1] + [END_OF_LINE, END_OF_LINE]
        return tokens
 | 
			
		||||
 | 
			
		||||
    def run_rnn(self, _tokens: List[str], newline_adj: int = 0):
 | 
			
		||||
        tokens = [int(x) for x in _tokens]
 | 
			
		||||
        self.model_tokens += tokens
 | 
			
		||||
@ -112,7 +119,7 @@ The following is a coherent verbose detailed conversation between a girl named {
 | 
			
		||||
            logits = copy.deepcopy(cache["logits"])
 | 
			
		||||
 | 
			
		||||
        if delta_prompt != "":
 | 
			
		||||
            logits = self.run_rnn(self.pipeline.encode(delta_prompt))
 | 
			
		||||
            logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(delta_prompt)))
 | 
			
		||||
            try:
 | 
			
		||||
                state_cache.add_state(
 | 
			
		||||
                    state_cache.AddStateBody(
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user