expose global_penalty
This commit is contained in:
@@ -40,6 +40,7 @@ class AbstractRWKV(ABC):
|
||||
self.penalty_alpha_presence = 0
|
||||
self.penalty_alpha_frequency = 1
|
||||
self.penalty_decay = 0.996
|
||||
self.global_penalty = False
|
||||
|
||||
@abstractmethod
|
||||
def adjust_occurrence(self, occurrence: Dict, token: int):
|
||||
@@ -403,8 +404,8 @@ class TextRWKV(AbstractRWKV):
|
||||
+ occurrence[n] * self.penalty_alpha_frequency
|
||||
)
|
||||
|
||||
# comment the codes below to get the same generated results as the official RWKV Gradio
|
||||
if i == 0:
|
||||
# set global_penalty to False to get the same generated results as the official RWKV Gradio
|
||||
if self.global_penalty and i == 0:
|
||||
for token in self.model_tokens:
|
||||
token = int(token)
|
||||
if token not in self.AVOID_PENALTY_TOKENS:
|
||||
@@ -673,6 +674,7 @@ class ModelConfigBody(BaseModel):
|
||||
frequency_penalty: float = Field(default=None, ge=-2, le=2)
|
||||
penalty_decay: float = Field(default=None, ge=0.99, le=0.999)
|
||||
top_k: int = Field(default=None, ge=0, le=25)
|
||||
global_penalty: bool = Field(default=None)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
@@ -683,6 +685,7 @@ class ModelConfigBody(BaseModel):
|
||||
"presence_penalty": 0,
|
||||
"frequency_penalty": 1,
|
||||
"penalty_decay": 0.996,
|
||||
"global_penalty": False,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -706,6 +709,8 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
|
||||
model.penalty_decay = body.penalty_decay
|
||||
if body.top_k is not None:
|
||||
model.top_k = body.top_k
|
||||
if body.global_penalty is not None:
|
||||
model.global_penalty = body.global_penalty
|
||||
|
||||
|
||||
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
|
||||
@@ -717,4 +722,5 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
|
||||
frequency_penalty=model.penalty_alpha_frequency,
|
||||
penalty_decay=model.penalty_decay,
|
||||
top_k=model.top_k,
|
||||
global_penalty=model.global_penalty,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user