preliminary usable features
This commit is contained in:
@@ -1,5 +1,37 @@
|
||||
from typing import Dict
|
||||
from langchain.llms import RWKV
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelConfigBody(BaseModel):
|
||||
max_tokens: int = None
|
||||
temperature: float = None
|
||||
top_p: float = None
|
||||
presence_penalty: float = None
|
||||
frequency_penalty: float = None
|
||||
|
||||
|
||||
def set_rwkv_config(model: RWKV, body: ModelConfigBody):
|
||||
if body.max_tokens:
|
||||
model.max_tokens_per_generation = body.max_tokens
|
||||
if body.temperature:
|
||||
model.temperature = body.temperature
|
||||
if body.top_p:
|
||||
model.top_p = body.top_p
|
||||
if body.presence_penalty:
|
||||
model.penalty_alpha_presence = body.presence_penalty
|
||||
if body.frequency_penalty:
|
||||
model.penalty_alpha_frequency = body.frequency_penalty
|
||||
|
||||
|
||||
def get_rwkv_config(model: RWKV) -> ModelConfigBody:
|
||||
return ModelConfigBody(
|
||||
max_tokens=model.max_tokens_per_generation,
|
||||
temperature=model.temperature,
|
||||
top_p=model.top_p,
|
||||
presence_penalty=model.penalty_alpha_presence,
|
||||
frequency_penalty=model.penalty_alpha_frequency,
|
||||
)
|
||||
|
||||
|
||||
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):
|
||||
|
||||
Reference in New Issue
Block a user