improve dynamic state api

This commit is contained in:
josc146 2024-05-16 13:50:48 +08:00
parent b24a18cd3a
commit d66fd89947
2 changed files with 31 additions and 12 deletions

View File

@ -120,7 +120,9 @@ def update_config(body: ModelConfigBody):
model_config = ModelConfigBody() model_config = ModelConfigBody()
global_var.set(global_var.Model_Config, model_config) global_var.set(global_var.Model_Config, model_config)
merge_model(model_config, body) merge_model(model_config, body)
exception = load_rwkv_state(global_var.get(global_var.Model), model_config.state) exception = load_rwkv_state(
global_var.get(global_var.Model), model_config.state, True
)
if exception is not None: if exception is not None:
raise exception raise exception
print("Updated Model Config:", model_config) print("Updated Model Config:", model_config)

View File

@ -716,7 +716,9 @@ class ModelConfigBody(BaseModel):
} }
def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException: def load_rwkv_state(
model: AbstractRWKV, state_path: str, print_log: bool = True
) -> HTTPException:
if model: if model:
if state_path: if state_path:
if model.model_path.endswith(".pth") and state_path.endswith(".pth"): if model.model_path.endswith(".pth") and state_path.endswith(".pth"):
@ -726,7 +728,18 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
if model.state_path == state_path: if model.state_path == state_path:
return return
state_raw = torch.load(state_path, map_location="cpu") if not os.path.isfile(state_path):
return HTTPException(
status.HTTP_400_BAD_REQUEST, "state file not found"
)
try:
state_raw = torch.load(state_path, map_location="cpu")
except Exception as e:
print(e)
return HTTPException(
status.HTTP_400_BAD_REQUEST, "state file failed to load"
)
state_raw_shape = next(iter(state_raw.values())).shape state_raw_shape = next(iter(state_raw.values())).shape
args = model.model.args args = model.model.args
@ -736,7 +749,7 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
): ):
if model.state_path: if model.state_path:
pass pass
else: elif print_log:
print("state failed to load") print("state failed to load")
return HTTPException( return HTTPException(
status.HTTP_400_BAD_REQUEST, "state shape mismatch" status.HTTP_400_BAD_REQUEST, "state shape mismatch"
@ -765,23 +778,27 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
state_cache.force_reset_state() state_cache.force_reset_state()
model.state_path = state_path model.state_path = state_path
print("state loaded") if print_log:
print("state loaded")
else: else:
if model.state_path: if model.state_path:
pass pass
else: elif print_log:
print("state failed to load") print("state failed to load")
return HTTPException( return HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
"file format of the model or state model not supported", "file format of the model or state model not supported",
) )
else: else:
state_cache.force_reset_state() if state_path == "" and model.state_path != "":
model.state_path = "" state_cache.force_reset_state()
model.state_tuned = None # TODO cached model.state_path = ""
print("state unloaded") model.state_tuned = None # TODO cached
if print_log:
print("state unloaded")
else: else:
print("state not loaded") if print_log:
print("state not loaded")
def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody): def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
@ -805,7 +822,7 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
if body.global_penalty is not None: if body.global_penalty is not None:
model.global_penalty = body.global_penalty model.global_penalty = body.global_penalty
if body.state is not None: if body.state is not None:
load_rwkv_state(model, body.state) load_rwkv_state(model, body.state, False)
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody: def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody: