safe ModelConfigBody

This commit is contained in:
josc146 2023-05-30 23:13:27 +08:00
parent 1f3f6cf9a8
commit 8291c50058

View File

@ -3,7 +3,7 @@ import pathlib
import copy import copy
from typing import Dict, List from typing import Dict, List
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel from pydantic import BaseModel, Field
from rwkv_pip.utils import PIPELINE from rwkv_pip.utils import PIPELINE
from routes import state_cache from routes import state_cache
@ -182,23 +182,23 @@ The following is a coherent verbose detailed conversation between a girl named {
class ModelConfigBody(BaseModel): class ModelConfigBody(BaseModel):
max_tokens: int = None max_tokens: int = Field(default=None, gt=0, le=102400)
temperature: float = None temperature: float = Field(default=None, ge=0, le=2)
top_p: float = None top_p: float = Field(default=None, ge=0, le=1)
presence_penalty: float = None presence_penalty: float = Field(default=None, ge=-2, le=2)
frequency_penalty: float = None frequency_penalty: float = Field(default=None, ge=-2, le=2)
def set_rwkv_config(model: RWKV, body: ModelConfigBody): def set_rwkv_config(model: RWKV, body: ModelConfigBody):
if body.max_tokens: if body.max_tokens is not None:
model.max_tokens_per_generation = body.max_tokens model.max_tokens_per_generation = body.max_tokens
if body.temperature: if body.temperature is not None:
model.temperature = body.temperature model.temperature = body.temperature
if body.top_p: if body.top_p is not None:
model.top_p = body.top_p model.top_p = body.top_p
if body.presence_penalty: if body.presence_penalty is not None:
model.penalty_alpha_presence = body.presence_penalty model.penalty_alpha_presence = body.presence_penalty
if body.frequency_penalty: if body.frequency_penalty is not None:
model.penalty_alpha_frequency = body.frequency_penalty model.penalty_alpha_frequency = body.frequency_penalty