From d66fd89947f84640c8f7e8a9da75f1b213b84ed1 Mon Sep 17 00:00:00 2001 From: josc146 Date: Thu, 16 May 2024 13:50:48 +0800 Subject: [PATCH] improve dynamic state api --- backend-python/routes/config.py | 4 +++- backend-python/utils/rwkv.py | 39 +++++++++++++++++++++++---------- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py index b1c196a..601014d 100644 --- a/backend-python/routes/config.py +++ b/backend-python/routes/config.py @@ -120,7 +120,9 @@ def update_config(body: ModelConfigBody): model_config = ModelConfigBody() global_var.set(global_var.Model_Config, model_config) merge_model(model_config, body) - exception = load_rwkv_state(global_var.get(global_var.Model), model_config.state) + exception = load_rwkv_state( + global_var.get(global_var.Model), model_config.state, True + ) if exception is not None: raise exception print("Updated Model Config:", model_config) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index 76438fc..430fee4 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -716,7 +716,9 @@ class ModelConfigBody(BaseModel): } -def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException: +def load_rwkv_state( + model: AbstractRWKV, state_path: str, print_log: bool = True +) -> HTTPException: if model: if state_path: if model.model_path.endswith(".pth") and state_path.endswith(".pth"): @@ -726,7 +728,18 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException: if model.state_path == state_path: return - state_raw = torch.load(state_path, map_location="cpu") + if not os.path.isfile(state_path): + return HTTPException( + status.HTTP_400_BAD_REQUEST, "state file not found" + ) + + try: + state_raw = torch.load(state_path, map_location="cpu") + except Exception as e: + print(e) + return HTTPException( + status.HTTP_400_BAD_REQUEST, "state file failed to load" + ) state_raw_shape = next(iter(state_raw.values())).shape args = model.model.args @@ -736,7 +749,7 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException: ): if model.state_path: pass - else: + elif print_log: print("state failed to load") return HTTPException( status.HTTP_400_BAD_REQUEST, "state shape mismatch" @@ -765,23 +778,27 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException: state_cache.force_reset_state() model.state_path = state_path - print("state loaded") + if print_log: + print("state loaded") else: if model.state_path: pass - else: + elif print_log: print("state failed to load") return HTTPException( status.HTTP_400_BAD_REQUEST, "file format of the model or state model not supported", ) else: - state_cache.force_reset_state() - model.state_path = "" - model.state_tuned = None # TODO cached - print("state unloaded") + if state_path == "" and model.state_path != "": + state_cache.force_reset_state() + model.state_path = "" + model.state_tuned = None # TODO cached + if print_log: + print("state unloaded") else: - print("state not loaded") + if print_log: + print("state not loaded") def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody): @@ -805,7 +822,7 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody): if body.global_penalty is not None: model.global_penalty = body.global_penalty if body.state is not None: - load_rwkv_state(model, body.state) + load_rwkv_state(model, body.state, False) def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody: