expose penalty_decay, top_k
This commit is contained in:
parent
7cba526913
commit
843840baa0
@ -70,10 +70,10 @@ class ChatCompletionBody(ModelConfigBody):
|
|||||||
"assistant_name": None,
|
"assistant_name": None,
|
||||||
"presystem": True,
|
"presystem": True,
|
||||||
"max_tokens": 1000,
|
"max_tokens": 1000,
|
||||||
"temperature": 1.2,
|
"temperature": 1,
|
||||||
"top_p": 0.5,
|
"top_p": 0.3,
|
||||||
"presence_penalty": 0.4,
|
"presence_penalty": 0,
|
||||||
"frequency_penalty": 0.4,
|
"frequency_penalty": 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -94,10 +94,10 @@ class CompletionBody(ModelConfigBody):
|
|||||||
"stream": False,
|
"stream": False,
|
||||||
"stop": None,
|
"stop": None,
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"temperature": 1.2,
|
"temperature": 1,
|
||||||
"top_p": 0.5,
|
"top_p": 0.3,
|
||||||
"presence_penalty": 0.4,
|
"presence_penalty": 0,
|
||||||
"frequency_penalty": 0.4,
|
"frequency_penalty": 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -39,6 +39,7 @@ class AbstractRWKV(ABC):
|
|||||||
self.top_k = 0
|
self.top_k = 0
|
||||||
self.penalty_alpha_presence = 0
|
self.penalty_alpha_presence = 0
|
||||||
self.penalty_alpha_frequency = 1
|
self.penalty_alpha_frequency = 1
|
||||||
|
self.penalty_decay = 0.996
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def adjust_occurrence(self, occurrence: Dict, token: int):
|
def adjust_occurrence(self, occurrence: Dict, token: int):
|
||||||
@ -382,7 +383,7 @@ class TextRWKV(AbstractRWKV):
|
|||||||
|
|
||||||
def adjust_occurrence(self, occurrence: Dict, token: int):
|
def adjust_occurrence(self, occurrence: Dict, token: int):
|
||||||
for xxx in occurrence:
|
for xxx in occurrence:
|
||||||
occurrence[xxx] *= 0.996
|
occurrence[xxx] *= self.penalty_decay
|
||||||
if token not in occurrence:
|
if token not in occurrence:
|
||||||
occurrence[token] = 1
|
occurrence[token] = 1
|
||||||
else:
|
else:
|
||||||
@ -399,7 +400,7 @@ class TextRWKV(AbstractRWKV):
|
|||||||
for token in self.model_tokens:
|
for token in self.model_tokens:
|
||||||
token = int(token)
|
token = int(token)
|
||||||
for xxx in occurrence:
|
for xxx in occurrence:
|
||||||
occurrence[xxx] *= 0.996
|
occurrence[xxx] *= self.penalty_decay
|
||||||
if token not in occurrence:
|
if token not in occurrence:
|
||||||
occurrence[token] = 1
|
occurrence[token] = 1
|
||||||
else:
|
else:
|
||||||
@ -664,15 +665,18 @@ class ModelConfigBody(BaseModel):
|
|||||||
top_p: float = Field(default=None, ge=0, le=1)
|
top_p: float = Field(default=None, ge=0, le=1)
|
||||||
presence_penalty: float = Field(default=None, ge=-2, le=2)
|
presence_penalty: float = Field(default=None, ge=-2, le=2)
|
||||||
frequency_penalty: float = Field(default=None, ge=-2, le=2)
|
frequency_penalty: float = Field(default=None, ge=-2, le=2)
|
||||||
|
penalty_decay: float = Field(default=None, ge=0.99, le=0.999)
|
||||||
|
top_k: int = Field(default=None, ge=0, le=25)
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"json_schema_extra": {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"max_tokens": 1000,
|
"max_tokens": 1000,
|
||||||
"temperature": 1.2,
|
"temperature": 1,
|
||||||
"top_p": 0.5,
|
"top_p": 0.3,
|
||||||
"presence_penalty": 0.4,
|
"presence_penalty": 0,
|
||||||
"frequency_penalty": 0.4,
|
"frequency_penalty": 1,
|
||||||
|
"penalty_decay": 0.996,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -692,6 +696,10 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
|
|||||||
model.penalty_alpha_presence = body.presence_penalty
|
model.penalty_alpha_presence = body.presence_penalty
|
||||||
if body.frequency_penalty is not None:
|
if body.frequency_penalty is not None:
|
||||||
model.penalty_alpha_frequency = body.frequency_penalty
|
model.penalty_alpha_frequency = body.frequency_penalty
|
||||||
|
if body.penalty_decay is not None:
|
||||||
|
model.penalty_decay = body.penalty_decay
|
||||||
|
if body.top_k is not None:
|
||||||
|
model.top_k = body.top_k
|
||||||
|
|
||||||
|
|
||||||
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
|
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
|
||||||
@ -701,4 +709,6 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
|
|||||||
top_p=model.top_p,
|
top_p=model.top_p,
|
||||||
presence_penalty=model.penalty_alpha_presence,
|
presence_penalty=model.penalty_alpha_presence,
|
||||||
frequency_penalty=model.penalty_alpha_frequency,
|
frequency_penalty=model.penalty_alpha_frequency,
|
||||||
|
penalty_decay=model.penalty_decay,
|
||||||
|
top_k=model.top_k,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user