expose penalty_decay, top_k
This commit is contained in:
parent
7cba526913
commit
843840baa0
@ -70,10 +70,10 @@ class ChatCompletionBody(ModelConfigBody):
|
||||
"assistant_name": None,
|
||||
"presystem": True,
|
||||
"max_tokens": 1000,
|
||||
"temperature": 1.2,
|
||||
"top_p": 0.5,
|
||||
"presence_penalty": 0.4,
|
||||
"frequency_penalty": 0.4,
|
||||
"temperature": 1,
|
||||
"top_p": 0.3,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -94,10 +94,10 @@ class CompletionBody(ModelConfigBody):
|
||||
"stream": False,
|
||||
"stop": None,
|
||||
"max_tokens": 100,
|
||||
"temperature": 1.2,
|
||||
"top_p": 0.5,
|
||||
"presence_penalty": 0.4,
|
||||
"frequency_penalty": 0.4,
|
||||
"temperature": 1,
|
||||
"top_p": 0.3,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -39,6 +39,7 @@ class AbstractRWKV(ABC):
|
||||
self.top_k = 0
|
||||
self.penalty_alpha_presence = 0
|
||||
self.penalty_alpha_frequency = 1
|
||||
self.penalty_decay = 0.996
|
||||
|
||||
@abstractmethod
|
||||
def adjust_occurrence(self, occurrence: Dict, token: int):
|
||||
@ -382,7 +383,7 @@ class TextRWKV(AbstractRWKV):
|
||||
|
||||
def adjust_occurrence(self, occurrence: Dict, token: int):
|
||||
for xxx in occurrence:
|
||||
occurrence[xxx] *= 0.996
|
||||
occurrence[xxx] *= self.penalty_decay
|
||||
if token not in occurrence:
|
||||
occurrence[token] = 1
|
||||
else:
|
||||
@ -399,7 +400,7 @@ class TextRWKV(AbstractRWKV):
|
||||
for token in self.model_tokens:
|
||||
token = int(token)
|
||||
for xxx in occurrence:
|
||||
occurrence[xxx] *= 0.996
|
||||
occurrence[xxx] *= self.penalty_decay
|
||||
if token not in occurrence:
|
||||
occurrence[token] = 1
|
||||
else:
|
||||
@ -664,15 +665,18 @@ class ModelConfigBody(BaseModel):
|
||||
top_p: float = Field(default=None, ge=0, le=1)
|
||||
presence_penalty: float = Field(default=None, ge=-2, le=2)
|
||||
frequency_penalty: float = Field(default=None, ge=-2, le=2)
|
||||
penalty_decay: float = Field(default=None, ge=0.99, le=0.999)
|
||||
top_k: int = Field(default=None, ge=0, le=25)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 1.2,
|
||||
"top_p": 0.5,
|
||||
"presence_penalty": 0.4,
|
||||
"frequency_penalty": 0.4,
|
||||
"temperature": 1,
|
||||
"top_p": 0.3,
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 1,
|
||||
"penalty_decay": 0.996,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -692,6 +696,10 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
|
||||
model.penalty_alpha_presence = body.presence_penalty
|
||||
if body.frequency_penalty is not None:
|
||||
model.penalty_alpha_frequency = body.frequency_penalty
|
||||
if body.penalty_decay is not None:
|
||||
model.penalty_decay = body.penalty_decay
|
||||
if body.top_k is not None:
|
||||
model.top_k = body.top_k
|
||||
|
||||
|
||||
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
|
||||
@ -701,4 +709,6 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
|
||||
top_p=model.top_p,
|
||||
presence_penalty=model.penalty_alpha_presence,
|
||||
frequency_penalty=model.penalty_alpha_frequency,
|
||||
penalty_decay=model.penalty_decay,
|
||||
top_k=model.top_k,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user