fix /update-config can make the default value of unclearly specified fields invalid by passing in None fields

This commit is contained in:
josc146 2024-02-05 22:27:02 +08:00
parent 0703993bfd
commit a1ae71d221
2 changed files with 75 additions and 48 deletions

View File

@ -144,6 +144,7 @@ async def eval_rwkv(
return
set_rwkv_config(model, global_var.get(global_var.Model_Config))
set_rwkv_config(model, body)
print(get_rwkv_config(model))
response, prompt_tokens, completion_tokens = "", 0, 0
for response, delta, prompt_tokens, completion_tokens in model.generate(
@ -155,12 +156,15 @@ async def eval_rwkv(
if stream:
yield json.dumps(
{
"object": "chat.completion.chunk"
"object": (
"chat.completion.chunk"
if chat_mode
else "text_completion",
else "text_completion"
),
# "response": response,
"model": model.name,
"choices": [
(
{
"delta": {"content": delta},
"index": 0,
@ -172,6 +176,7 @@ async def eval_rwkv(
"index": 0,
"finish_reason": None,
}
)
],
}
)
@ -193,12 +198,13 @@ async def eval_rwkv(
if stream:
yield json.dumps(
{
"object": "chat.completion.chunk"
if chat_mode
else "text_completion",
"object": (
"chat.completion.chunk" if chat_mode else "text_completion"
),
# "response": response,
"model": model.name,
"choices": [
(
{
"delta": {},
"index": 0,
@ -210,6 +216,7 @@ async def eval_rwkv(
"index": 0,
"finish_reason": "stop",
}
)
],
}
)
@ -225,6 +232,7 @@ async def eval_rwkv(
"total_tokens": prompt_tokens + completion_tokens,
},
"choices": [
(
{
"message": {
"role": Role.Assistant.value,
@ -239,6 +247,7 @@ async def eval_rwkv(
"index": 0,
"finish_reason": "stop",
}
)
],
}

View File

@ -86,23 +86,41 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
if body.deploy:
global_var.set(global_var.Deploy_Mode, True)
if global_var.get(global_var.Model_Config) is None:
global_var.set(
global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model))
)
saved_model_config = global_var.get(global_var.Model_Config)
init_model_config = get_rwkv_config(global_var.get(global_var.Model))
if saved_model_config is not None:
merge_model(init_model_config, saved_model_config)
global_var.set(global_var.Model_Config, init_model_config)
global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
return "success"
def merge_model(to_model: BaseModel, from_model: BaseModel):
from_model_fields = [x for x in from_model.dict().keys()]
to_model_fields = [x for x in to_model.dict().keys()]
for field_name in from_model_fields:
if field_name in to_model_fields:
from_value = getattr(from_model, field_name)
if from_value is not None:
setattr(to_model, field_name, from_value)
@router.post("/update-config", tags=["Configs"])
def update_config(body: ModelConfigBody):
"""
Will not update the model config immediately, but set it when completion called to avoid modifications during generation
"""
print(body)
global_var.set(global_var.Model_Config, body)
model_config = global_var.get(global_var.Model_Config)
if model_config is None:
model_config = ModelConfigBody()
global_var.set(global_var.Model_Config, model_config)
merge_model(model_config, body)
print("Updated Model Config:", model_config)
return "success"