From a1ae71d2218784dc2285cc73a2c3fca2f1e9d9fc Mon Sep 17 00:00:00 2001 From: josc146 Date: Mon, 5 Feb 2024 22:27:02 +0800 Subject: [PATCH] fix /update-config can make the default value of unclearly specified fields invalid by passing in None fields --- backend-python/routes/completion.py | 93 ++++++++++++++++------------- backend-python/routes/config.py | 30 ++++++++-- 2 files changed, 75 insertions(+), 48 deletions(-) diff --git a/backend-python/routes/completion.py b/backend-python/routes/completion.py index e0beb93..2129c33 100644 --- a/backend-python/routes/completion.py +++ b/backend-python/routes/completion.py @@ -144,6 +144,7 @@ async def eval_rwkv( return set_rwkv_config(model, global_var.get(global_var.Model_Config)) set_rwkv_config(model, body) + print(get_rwkv_config(model)) response, prompt_tokens, completion_tokens = "", 0, 0 for response, delta, prompt_tokens, completion_tokens in model.generate( @@ -155,23 +156,27 @@ async def eval_rwkv( if stream: yield json.dumps( { - "object": "chat.completion.chunk" - if chat_mode - else "text_completion", + "object": ( + "chat.completion.chunk" + if chat_mode + else "text_completion" + ), # "response": response, "model": model.name, "choices": [ - { - "delta": {"content": delta}, - "index": 0, - "finish_reason": None, - } - if chat_mode - else { - "text": delta, - "index": 0, - "finish_reason": None, - } + ( + { + "delta": {"content": delta}, + "index": 0, + "finish_reason": None, + } + if chat_mode + else { + "text": delta, + "index": 0, + "finish_reason": None, + } + ) ], } ) @@ -193,23 +198,25 @@ async def eval_rwkv( if stream: yield json.dumps( { - "object": "chat.completion.chunk" - if chat_mode - else "text_completion", + "object": ( + "chat.completion.chunk" if chat_mode else "text_completion" + ), # "response": response, "model": model.name, "choices": [ - { - "delta": {}, - "index": 0, - "finish_reason": "stop", - } - if chat_mode - else { - "text": "", - "index": 0, - "finish_reason": "stop", - } + ( + { + "delta": {}, + "index": 0, + "finish_reason": "stop", + } + if chat_mode + else { + "text": "", + "index": 0, + "finish_reason": "stop", + } + ) ], } ) @@ -225,20 +232,22 @@ async def eval_rwkv( "total_tokens": prompt_tokens + completion_tokens, }, "choices": [ - { - "message": { - "role": Role.Assistant.value, - "content": response, - }, - "index": 0, - "finish_reason": "stop", - } - if chat_mode - else { - "text": response, - "index": 0, - "finish_reason": "stop", - } + ( + { + "message": { + "role": Role.Assistant.value, + "content": response, + }, + "index": 0, + "finish_reason": "stop", + } + if chat_mode + else { + "text": response, + "index": 0, + "finish_reason": "stop", + } + ) ], } diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py index f031d43..aa0ab1f 100644 --- a/backend-python/routes/config.py +++ b/backend-python/routes/config.py @@ -86,23 +86,41 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request): if body.deploy: global_var.set(global_var.Deploy_Mode, True) - if global_var.get(global_var.Model_Config) is None: - global_var.set( - global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model)) - ) + + saved_model_config = global_var.get(global_var.Model_Config) + init_model_config = get_rwkv_config(global_var.get(global_var.Model)) + if saved_model_config is not None: + merge_model(init_model_config, saved_model_config) + global_var.set(global_var.Model_Config, init_model_config) global_var.set(global_var.Model_Status, global_var.ModelStatus.Working) return "success" +def merge_model(to_model: BaseModel, from_model: BaseModel): + from_model_fields = [x for x in from_model.dict().keys()] + to_model_fields = [x for x in to_model.dict().keys()] + + for field_name in from_model_fields: + if field_name in to_model_fields: + from_value = getattr(from_model, field_name) + + if from_value is not None: + setattr(to_model, field_name, from_value) + + @router.post("/update-config", tags=["Configs"]) def update_config(body: ModelConfigBody): """ Will not update the model config immediately, but set it when completion called to avoid modifications during generation """ - print(body) - global_var.set(global_var.Model_Config, body) + model_config = global_var.get(global_var.Model_Config) + if model_config is None: + model_config = ModelConfigBody() + global_var.set(global_var.Model_Config, model_config) + merge_model(model_config, body) + print("Updated Model Config:", model_config) return "success"