fix /update-config can make the default value of unclearly specified fields invalid by passing in None fields
This commit is contained in:
parent
0703993bfd
commit
a1ae71d221
@ -144,6 +144,7 @@ async def eval_rwkv(
|
|||||||
return
|
return
|
||||||
set_rwkv_config(model, global_var.get(global_var.Model_Config))
|
set_rwkv_config(model, global_var.get(global_var.Model_Config))
|
||||||
set_rwkv_config(model, body)
|
set_rwkv_config(model, body)
|
||||||
|
print(get_rwkv_config(model))
|
||||||
|
|
||||||
response, prompt_tokens, completion_tokens = "", 0, 0
|
response, prompt_tokens, completion_tokens = "", 0, 0
|
||||||
for response, delta, prompt_tokens, completion_tokens in model.generate(
|
for response, delta, prompt_tokens, completion_tokens in model.generate(
|
||||||
@ -155,12 +156,15 @@ async def eval_rwkv(
|
|||||||
if stream:
|
if stream:
|
||||||
yield json.dumps(
|
yield json.dumps(
|
||||||
{
|
{
|
||||||
"object": "chat.completion.chunk"
|
"object": (
|
||||||
|
"chat.completion.chunk"
|
||||||
if chat_mode
|
if chat_mode
|
||||||
else "text_completion",
|
else "text_completion"
|
||||||
|
),
|
||||||
# "response": response,
|
# "response": response,
|
||||||
"model": model.name,
|
"model": model.name,
|
||||||
"choices": [
|
"choices": [
|
||||||
|
(
|
||||||
{
|
{
|
||||||
"delta": {"content": delta},
|
"delta": {"content": delta},
|
||||||
"index": 0,
|
"index": 0,
|
||||||
@ -172,6 +176,7 @@ async def eval_rwkv(
|
|||||||
"index": 0,
|
"index": 0,
|
||||||
"finish_reason": None,
|
"finish_reason": None,
|
||||||
}
|
}
|
||||||
|
)
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -193,12 +198,13 @@ async def eval_rwkv(
|
|||||||
if stream:
|
if stream:
|
||||||
yield json.dumps(
|
yield json.dumps(
|
||||||
{
|
{
|
||||||
"object": "chat.completion.chunk"
|
"object": (
|
||||||
if chat_mode
|
"chat.completion.chunk" if chat_mode else "text_completion"
|
||||||
else "text_completion",
|
),
|
||||||
# "response": response,
|
# "response": response,
|
||||||
"model": model.name,
|
"model": model.name,
|
||||||
"choices": [
|
"choices": [
|
||||||
|
(
|
||||||
{
|
{
|
||||||
"delta": {},
|
"delta": {},
|
||||||
"index": 0,
|
"index": 0,
|
||||||
@ -210,6 +216,7 @@ async def eval_rwkv(
|
|||||||
"index": 0,
|
"index": 0,
|
||||||
"finish_reason": "stop",
|
"finish_reason": "stop",
|
||||||
}
|
}
|
||||||
|
)
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -225,6 +232,7 @@ async def eval_rwkv(
|
|||||||
"total_tokens": prompt_tokens + completion_tokens,
|
"total_tokens": prompt_tokens + completion_tokens,
|
||||||
},
|
},
|
||||||
"choices": [
|
"choices": [
|
||||||
|
(
|
||||||
{
|
{
|
||||||
"message": {
|
"message": {
|
||||||
"role": Role.Assistant.value,
|
"role": Role.Assistant.value,
|
||||||
@ -239,6 +247,7 @@ async def eval_rwkv(
|
|||||||
"index": 0,
|
"index": 0,
|
||||||
"finish_reason": "stop",
|
"finish_reason": "stop",
|
||||||
}
|
}
|
||||||
|
)
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,23 +86,41 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
|||||||
|
|
||||||
if body.deploy:
|
if body.deploy:
|
||||||
global_var.set(global_var.Deploy_Mode, True)
|
global_var.set(global_var.Deploy_Mode, True)
|
||||||
if global_var.get(global_var.Model_Config) is None:
|
|
||||||
global_var.set(
|
saved_model_config = global_var.get(global_var.Model_Config)
|
||||||
global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model))
|
init_model_config = get_rwkv_config(global_var.get(global_var.Model))
|
||||||
)
|
if saved_model_config is not None:
|
||||||
|
merge_model(init_model_config, saved_model_config)
|
||||||
|
global_var.set(global_var.Model_Config, init_model_config)
|
||||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
|
global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
|
||||||
|
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
|
|
||||||
|
def merge_model(to_model: BaseModel, from_model: BaseModel):
|
||||||
|
from_model_fields = [x for x in from_model.dict().keys()]
|
||||||
|
to_model_fields = [x for x in to_model.dict().keys()]
|
||||||
|
|
||||||
|
for field_name in from_model_fields:
|
||||||
|
if field_name in to_model_fields:
|
||||||
|
from_value = getattr(from_model, field_name)
|
||||||
|
|
||||||
|
if from_value is not None:
|
||||||
|
setattr(to_model, field_name, from_value)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/update-config", tags=["Configs"])
|
@router.post("/update-config", tags=["Configs"])
|
||||||
def update_config(body: ModelConfigBody):
|
def update_config(body: ModelConfigBody):
|
||||||
"""
|
"""
|
||||||
Will not update the model config immediately, but set it when completion called to avoid modifications during generation
|
Will not update the model config immediately, but set it when completion called to avoid modifications during generation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print(body)
|
model_config = global_var.get(global_var.Model_Config)
|
||||||
global_var.set(global_var.Model_Config, body)
|
if model_config is None:
|
||||||
|
model_config = ModelConfigBody()
|
||||||
|
global_var.set(global_var.Model_Config, model_config)
|
||||||
|
merge_model(model_config, body)
|
||||||
|
print("Updated Model Config:", model_config)
|
||||||
|
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user