improve dynamic state api
This commit is contained in:
parent
b24a18cd3a
commit
d66fd89947
@ -120,7 +120,9 @@ def update_config(body: ModelConfigBody):
|
|||||||
model_config = ModelConfigBody()
|
model_config = ModelConfigBody()
|
||||||
global_var.set(global_var.Model_Config, model_config)
|
global_var.set(global_var.Model_Config, model_config)
|
||||||
merge_model(model_config, body)
|
merge_model(model_config, body)
|
||||||
exception = load_rwkv_state(global_var.get(global_var.Model), model_config.state)
|
exception = load_rwkv_state(
|
||||||
|
global_var.get(global_var.Model), model_config.state, True
|
||||||
|
)
|
||||||
if exception is not None:
|
if exception is not None:
|
||||||
raise exception
|
raise exception
|
||||||
print("Updated Model Config:", model_config)
|
print("Updated Model Config:", model_config)
|
||||||
|
@ -716,7 +716,9 @@ class ModelConfigBody(BaseModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
|
def load_rwkv_state(
|
||||||
|
model: AbstractRWKV, state_path: str, print_log: bool = True
|
||||||
|
) -> HTTPException:
|
||||||
if model:
|
if model:
|
||||||
if state_path:
|
if state_path:
|
||||||
if model.model_path.endswith(".pth") and state_path.endswith(".pth"):
|
if model.model_path.endswith(".pth") and state_path.endswith(".pth"):
|
||||||
@ -726,7 +728,18 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
|
|||||||
if model.state_path == state_path:
|
if model.state_path == state_path:
|
||||||
return
|
return
|
||||||
|
|
||||||
state_raw = torch.load(state_path, map_location="cpu")
|
if not os.path.isfile(state_path):
|
||||||
|
return HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST, "state file not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
state_raw = torch.load(state_path, map_location="cpu")
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST, "state file failed to load"
|
||||||
|
)
|
||||||
state_raw_shape = next(iter(state_raw.values())).shape
|
state_raw_shape = next(iter(state_raw.values())).shape
|
||||||
|
|
||||||
args = model.model.args
|
args = model.model.args
|
||||||
@ -736,7 +749,7 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
|
|||||||
):
|
):
|
||||||
if model.state_path:
|
if model.state_path:
|
||||||
pass
|
pass
|
||||||
else:
|
elif print_log:
|
||||||
print("state failed to load")
|
print("state failed to load")
|
||||||
return HTTPException(
|
return HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST, "state shape mismatch"
|
status.HTTP_400_BAD_REQUEST, "state shape mismatch"
|
||||||
@ -765,23 +778,27 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
|
|||||||
|
|
||||||
state_cache.force_reset_state()
|
state_cache.force_reset_state()
|
||||||
model.state_path = state_path
|
model.state_path = state_path
|
||||||
print("state loaded")
|
if print_log:
|
||||||
|
print("state loaded")
|
||||||
else:
|
else:
|
||||||
if model.state_path:
|
if model.state_path:
|
||||||
pass
|
pass
|
||||||
else:
|
elif print_log:
|
||||||
print("state failed to load")
|
print("state failed to load")
|
||||||
return HTTPException(
|
return HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
"file format of the model or state model not supported",
|
"file format of the model or state model not supported",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
state_cache.force_reset_state()
|
if state_path == "" and model.state_path != "":
|
||||||
model.state_path = ""
|
state_cache.force_reset_state()
|
||||||
model.state_tuned = None # TODO cached
|
model.state_path = ""
|
||||||
print("state unloaded")
|
model.state_tuned = None # TODO cached
|
||||||
|
if print_log:
|
||||||
|
print("state unloaded")
|
||||||
else:
|
else:
|
||||||
print("state not loaded")
|
if print_log:
|
||||||
|
print("state not loaded")
|
||||||
|
|
||||||
|
|
||||||
def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
|
def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
|
||||||
@ -805,7 +822,7 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
|
|||||||
if body.global_penalty is not None:
|
if body.global_penalty is not None:
|
||||||
model.global_penalty = body.global_penalty
|
model.global_penalty = body.global_penalty
|
||||||
if body.state is not None:
|
if body.state is not None:
|
||||||
load_rwkv_state(model, body.state)
|
load_rwkv_state(model, body.state, False)
|
||||||
|
|
||||||
|
|
||||||
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
|
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
|
||||||
|
Loading…
Reference in New Issue
Block a user