fix /update-config can make the default value of unclearly specified fields invalid by passing in None fields

This commit is contained in:
josc146 2024-02-05 22:27:02 +08:00
parent 0703993bfd
commit a1ae71d221
2 changed files with 75 additions and 48 deletions

View File

@ -144,6 +144,7 @@ async def eval_rwkv(
return return
set_rwkv_config(model, global_var.get(global_var.Model_Config)) set_rwkv_config(model, global_var.get(global_var.Model_Config))
set_rwkv_config(model, body) set_rwkv_config(model, body)
print(get_rwkv_config(model))
response, prompt_tokens, completion_tokens = "", 0, 0 response, prompt_tokens, completion_tokens = "", 0, 0
for response, delta, prompt_tokens, completion_tokens in model.generate( for response, delta, prompt_tokens, completion_tokens in model.generate(
@ -155,12 +156,15 @@ async def eval_rwkv(
if stream: if stream:
yield json.dumps( yield json.dumps(
{ {
"object": "chat.completion.chunk" "object": (
"chat.completion.chunk"
if chat_mode if chat_mode
else "text_completion", else "text_completion"
),
# "response": response, # "response": response,
"model": model.name, "model": model.name,
"choices": [ "choices": [
(
{ {
"delta": {"content": delta}, "delta": {"content": delta},
"index": 0, "index": 0,
@ -172,6 +176,7 @@ async def eval_rwkv(
"index": 0, "index": 0,
"finish_reason": None, "finish_reason": None,
} }
)
], ],
} }
) )
@ -193,12 +198,13 @@ async def eval_rwkv(
if stream: if stream:
yield json.dumps( yield json.dumps(
{ {
"object": "chat.completion.chunk" "object": (
if chat_mode "chat.completion.chunk" if chat_mode else "text_completion"
else "text_completion", ),
# "response": response, # "response": response,
"model": model.name, "model": model.name,
"choices": [ "choices": [
(
{ {
"delta": {}, "delta": {},
"index": 0, "index": 0,
@ -210,6 +216,7 @@ async def eval_rwkv(
"index": 0, "index": 0,
"finish_reason": "stop", "finish_reason": "stop",
} }
)
], ],
} }
) )
@ -225,6 +232,7 @@ async def eval_rwkv(
"total_tokens": prompt_tokens + completion_tokens, "total_tokens": prompt_tokens + completion_tokens,
}, },
"choices": [ "choices": [
(
{ {
"message": { "message": {
"role": Role.Assistant.value, "role": Role.Assistant.value,
@ -239,6 +247,7 @@ async def eval_rwkv(
"index": 0, "index": 0,
"finish_reason": "stop", "finish_reason": "stop",
} }
)
], ],
} }

View File

@ -86,23 +86,41 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
if body.deploy: if body.deploy:
global_var.set(global_var.Deploy_Mode, True) global_var.set(global_var.Deploy_Mode, True)
if global_var.get(global_var.Model_Config) is None:
global_var.set( saved_model_config = global_var.get(global_var.Model_Config)
global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model)) init_model_config = get_rwkv_config(global_var.get(global_var.Model))
) if saved_model_config is not None:
merge_model(init_model_config, saved_model_config)
global_var.set(global_var.Model_Config, init_model_config)
global_var.set(global_var.Model_Status, global_var.ModelStatus.Working) global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
return "success" return "success"
def merge_model(to_model: BaseModel, from_model: BaseModel):
from_model_fields = [x for x in from_model.dict().keys()]
to_model_fields = [x for x in to_model.dict().keys()]
for field_name in from_model_fields:
if field_name in to_model_fields:
from_value = getattr(from_model, field_name)
if from_value is not None:
setattr(to_model, field_name, from_value)
@router.post("/update-config", tags=["Configs"]) @router.post("/update-config", tags=["Configs"])
def update_config(body: ModelConfigBody): def update_config(body: ModelConfigBody):
""" """
Will not update the model config immediately, but set it when completion called to avoid modifications during generation Will not update the model config immediately, but set it when completion called to avoid modifications during generation
""" """
print(body) model_config = global_var.get(global_var.Model_Config)
global_var.set(global_var.Model_Config, body) if model_config is None:
model_config = ModelConfigBody()
global_var.set(global_var.Model_Config, model_config)
merge_model(model_config, body)
print("Updated Model Config:", model_config)
return "success" return "success"