safe ModelConfigBody
This commit is contained in:
parent
1f3f6cf9a8
commit
8291c50058
@ -3,7 +3,7 @@ import pathlib
|
||||
import copy
|
||||
from typing import Dict, List
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from rwkv_pip.utils import PIPELINE
|
||||
from routes import state_cache
|
||||
|
||||
@ -182,23 +182,23 @@ The following is a coherent verbose detailed conversation between a girl named {
|
||||
|
||||
|
||||
class ModelConfigBody(BaseModel):
|
||||
max_tokens: int = None
|
||||
temperature: float = None
|
||||
top_p: float = None
|
||||
presence_penalty: float = None
|
||||
frequency_penalty: float = None
|
||||
max_tokens: int = Field(default=None, gt=0, le=102400)
|
||||
temperature: float = Field(default=None, ge=0, le=2)
|
||||
top_p: float = Field(default=None, ge=0, le=1)
|
||||
presence_penalty: float = Field(default=None, ge=-2, le=2)
|
||||
frequency_penalty: float = Field(default=None, ge=-2, le=2)
|
||||
|
||||
|
||||
def set_rwkv_config(model: RWKV, body: ModelConfigBody):
|
||||
if body.max_tokens:
|
||||
if body.max_tokens is not None:
|
||||
model.max_tokens_per_generation = body.max_tokens
|
||||
if body.temperature:
|
||||
if body.temperature is not None:
|
||||
model.temperature = body.temperature
|
||||
if body.top_p:
|
||||
if body.top_p is not None:
|
||||
model.top_p = body.top_p
|
||||
if body.presence_penalty:
|
||||
if body.presence_penalty is not None:
|
||||
model.penalty_alpha_presence = body.presence_penalty
|
||||
if body.frequency_penalty:
|
||||
if body.frequency_penalty is not None:
|
||||
model.penalty_alpha_frequency = body.frequency_penalty
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user