diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index b8065b9..97dbc9e 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -3,7 +3,7 @@ import pathlib import copy from typing import Dict, List from fastapi import HTTPException -from pydantic import BaseModel +from pydantic import BaseModel, Field from rwkv_pip.utils import PIPELINE from routes import state_cache @@ -182,23 +182,23 @@ The following is a coherent verbose detailed conversation between a girl named { class ModelConfigBody(BaseModel): - max_tokens: int = None - temperature: float = None - top_p: float = None - presence_penalty: float = None - frequency_penalty: float = None + max_tokens: int = Field(default=None, gt=0, le=102400) + temperature: float = Field(default=None, ge=0, le=2) + top_p: float = Field(default=None, ge=0, le=1) + presence_penalty: float = Field(default=None, ge=-2, le=2) + frequency_penalty: float = Field(default=None, ge=-2, le=2) def set_rwkv_config(model: RWKV, body: ModelConfigBody): - if body.max_tokens: + if body.max_tokens is not None: model.max_tokens_per_generation = body.max_tokens - if body.temperature: + if body.temperature is not None: model.temperature = body.temperature - if body.top_p: + if body.top_p is not None: model.top_p = body.top_p - if body.presence_penalty: + if body.presence_penalty is not None: model.penalty_alpha_presence = body.presence_penalty - if body.frequency_penalty: + if body.frequency_penalty is not None: model.penalty_alpha_frequency = body.frequency_penalty