expose penalty_decay, top_k
This commit is contained in:
		
							parent
							
								
									7cba526913
								
							
						
					
					
						commit
						843840baa0
					
				@ -70,10 +70,10 @@ class ChatCompletionBody(ModelConfigBody):
 | 
			
		||||
                "assistant_name": None,
 | 
			
		||||
                "presystem": True,
 | 
			
		||||
                "max_tokens": 1000,
 | 
			
		||||
-                "temperature": 1.2,
 | 
			
		||||
-                "top_p": 0.5,
 | 
			
		||||
-                "presence_penalty": 0.4,
 | 
			
		||||
-                "frequency_penalty": 0.4,
 | 
			
		||||
+                "temperature": 1,
 | 
			
		||||
+                "top_p": 0.3,
 | 
			
		||||
+                "presence_penalty": 0,
 | 
			
		||||
+                "frequency_penalty": 1,
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
@ -94,10 +94,10 @@ class CompletionBody(ModelConfigBody):
 | 
			
		||||
                "stream": False,
 | 
			
		||||
                "stop": None,
 | 
			
		||||
                "max_tokens": 100,
 | 
			
		||||
-                "temperature": 1.2,
 | 
			
		||||
-                "top_p": 0.5,
 | 
			
		||||
-                "presence_penalty": 0.4,
 | 
			
		||||
-                "frequency_penalty": 0.4,
 | 
			
		||||
+                "temperature": 1,
 | 
			
		||||
+                "top_p": 0.3,
 | 
			
		||||
+                "presence_penalty": 0,
 | 
			
		||||
+                "frequency_penalty": 1,
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -39,6 +39,7 @@ class AbstractRWKV(ABC):
 | 
			
		||||
        self.top_k = 0
 | 
			
		||||
        self.penalty_alpha_presence = 0
 | 
			
		||||
        self.penalty_alpha_frequency = 1
 | 
			
		||||
+        self.penalty_decay = 0.996
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def adjust_occurrence(self, occurrence: Dict, token: int):
 | 
			
		||||
@ -382,7 +383,7 @@ class TextRWKV(AbstractRWKV):
 | 
			
		||||
 | 
			
		||||
    def adjust_occurrence(self, occurrence: Dict, token: int):
 | 
			
		||||
        for xxx in occurrence:
 | 
			
		||||
-            occurrence[xxx] *= 0.996
 | 
			
		||||
+            occurrence[xxx] *= self.penalty_decay
 | 
			
		||||
        if token not in occurrence:
 | 
			
		||||
            occurrence[token] = 1
 | 
			
		||||
        else:
 | 
			
		||||
@ -399,7 +400,7 @@ class TextRWKV(AbstractRWKV):
 | 
			
		||||
            for token in self.model_tokens:
 | 
			
		||||
                token = int(token)
 | 
			
		||||
                for xxx in occurrence:
 | 
			
		||||
-                    occurrence[xxx] *= 0.996
 | 
			
		||||
+                    occurrence[xxx] *= self.penalty_decay
 | 
			
		||||
                if token not in occurrence:
 | 
			
		||||
                    occurrence[token] = 1
 | 
			
		||||
                else:
 | 
			
		||||
@ -664,15 +665,18 @@ class ModelConfigBody(BaseModel):
 | 
			
		||||
    top_p: float = Field(default=None, ge=0, le=1)
 | 
			
		||||
    presence_penalty: float = Field(default=None, ge=-2, le=2)
 | 
			
		||||
    frequency_penalty: float = Field(default=None, ge=-2, le=2)
 | 
			
		||||
+    penalty_decay: float = Field(default=None, ge=0.99, le=0.999)
 | 
			
		||||
+    top_k: int = Field(default=None, ge=0, le=25)
 | 
			
		||||
 | 
			
		||||
    model_config = {
 | 
			
		||||
        "json_schema_extra": {
 | 
			
		||||
            "example": {
 | 
			
		||||
                "max_tokens": 1000,
 | 
			
		||||
-                "temperature": 1.2,
 | 
			
		||||
-                "top_p": 0.5,
 | 
			
		||||
-                "presence_penalty": 0.4,
 | 
			
		||||
-                "frequency_penalty": 0.4,
 | 
			
		||||
+                "temperature": 1,
 | 
			
		||||
+                "top_p": 0.3,
 | 
			
		||||
+                "presence_penalty": 0,
 | 
			
		||||
+                "frequency_penalty": 1,
 | 
			
		||||
+                "penalty_decay": 0.996,
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
@ -692,6 +696,10 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
 | 
			
		||||
        model.penalty_alpha_presence = body.presence_penalty
 | 
			
		||||
    if body.frequency_penalty is not None:
 | 
			
		||||
        model.penalty_alpha_frequency = body.frequency_penalty
 | 
			
		||||
+    if body.penalty_decay is not None:
 | 
			
		||||
+        model.penalty_decay = body.penalty_decay
 | 
			
		||||
+    if body.top_k is not None:
 | 
			
		||||
+        model.top_k = body.top_k
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
 | 
			
		||||
@ -701,4 +709,6 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
 | 
			
		||||
        top_p=model.top_p,
 | 
			
		||||
        presence_penalty=model.penalty_alpha_presence,
 | 
			
		||||
        frequency_penalty=model.penalty_alpha_frequency,
 | 
			
		||||
+        penalty_decay=model.penalty_decay,
 | 
			
		||||
+        top_k=model.top_k,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user