safe ModelConfigBody

This commit is contained in:
josc146 2023-05-30 23:13:27 +08:00
parent 1f3f6cf9a8
commit 8291c50058

View File

@ -3,7 +3,7 @@ import pathlib
import copy
from typing import Dict, List
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import BaseModel, Field
from rwkv_pip.utils import PIPELINE
from routes import state_cache
@ -182,23 +182,23 @@ The following is a coherent verbose detailed conversation between a girl named {
class ModelConfigBody(BaseModel):
max_tokens: int = None
temperature: float = None
top_p: float = None
presence_penalty: float = None
frequency_penalty: float = None
max_tokens: int = Field(default=None, gt=0, le=102400)
temperature: float = Field(default=None, ge=0, le=2)
top_p: float = Field(default=None, ge=0, le=1)
presence_penalty: float = Field(default=None, ge=-2, le=2)
frequency_penalty: float = Field(default=None, ge=-2, le=2)
def set_rwkv_config(model: RWKV, body: ModelConfigBody):
if body.max_tokens:
if body.max_tokens is not None:
model.max_tokens_per_generation = body.max_tokens
if body.temperature:
if body.temperature is not None:
model.temperature = body.temperature
if body.top_p:
if body.top_p is not None:
model.top_p = body.top_p
if body.presence_penalty:
if body.presence_penalty is not None:
model.penalty_alpha_presence = body.presence_penalty
if body.frequency_penalty:
if body.frequency_penalty is not None:
model.penalty_alpha_frequency = body.frequency_penalty