expose global_penalty

This commit is contained in:
josc146
2024-03-02 17:50:41 +08:00
parent 53a5574080
commit 4f14074a75
6 changed files with 92 additions and 31 deletions

View File

@@ -40,6 +40,7 @@ class AbstractRWKV(ABC):
self.penalty_alpha_presence = 0
self.penalty_alpha_frequency = 1
self.penalty_decay = 0.996
self.global_penalty = False
@abstractmethod
def adjust_occurrence(self, occurrence: Dict, token: int):
@@ -403,8 +404,8 @@ class TextRWKV(AbstractRWKV):
+ occurrence[n] * self.penalty_alpha_frequency
)
# comment the codes below to get the same generated results as the official RWKV Gradio
if i == 0:
# set global_penalty to False to get the same generated results as the official RWKV Gradio
if self.global_penalty and i == 0:
for token in self.model_tokens:
token = int(token)
if token not in self.AVOID_PENALTY_TOKENS:
@@ -673,6 +674,7 @@ class ModelConfigBody(BaseModel):
frequency_penalty: float = Field(default=None, ge=-2, le=2)
penalty_decay: float = Field(default=None, ge=0.99, le=0.999)
top_k: int = Field(default=None, ge=0, le=25)
global_penalty: bool = Field(default=None)
model_config = {
"json_schema_extra": {
@@ -683,6 +685,7 @@ class ModelConfigBody(BaseModel):
"presence_penalty": 0,
"frequency_penalty": 1,
"penalty_decay": 0.996,
"global_penalty": False,
}
}
}
@@ -706,6 +709,8 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
model.penalty_decay = body.penalty_decay
if body.top_k is not None:
model.top_k = body.top_k
if body.global_penalty is not None:
model.global_penalty = body.global_penalty
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
@@ -717,4 +722,5 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
frequency_penalty=model.penalty_alpha_frequency,
penalty_decay=model.penalty_decay,
top_k=model.top_k,
global_penalty=model.global_penalty,
)