Improve dynamic state API
This commit is contained in:
		
							parent
							
								
									b24a18cd3a
								
							
						
					
					
						commit
						d66fd89947
					
				@ -120,7 +120,9 @@ def update_config(body: ModelConfigBody):
 | 
			
		||||
        model_config = ModelConfigBody()
 | 
			
		||||
        global_var.set(global_var.Model_Config, model_config)
 | 
			
		||||
    merge_model(model_config, body)
 | 
			
		||||
    exception = load_rwkv_state(global_var.get(global_var.Model), model_config.state)
 | 
			
		||||
    exception = load_rwkv_state(
 | 
			
		||||
        global_var.get(global_var.Model), model_config.state, True
 | 
			
		||||
    )
 | 
			
		||||
    if exception is not None:
 | 
			
		||||
        raise exception
 | 
			
		||||
    print("Updated Model Config:", model_config)
 | 
			
		||||
 | 
			
		||||
@ -716,7 +716,9 @@ class ModelConfigBody(BaseModel):
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
 | 
			
		||||
def load_rwkv_state(
 | 
			
		||||
    model: AbstractRWKV, state_path: str, print_log: bool = True
 | 
			
		||||
) -> HTTPException:
 | 
			
		||||
    if model:
 | 
			
		||||
        if state_path:
 | 
			
		||||
            if model.model_path.endswith(".pth") and state_path.endswith(".pth"):
 | 
			
		||||
@ -726,7 +728,18 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
 | 
			
		||||
                if model.state_path == state_path:
 | 
			
		||||
                    return
 | 
			
		||||
 | 
			
		||||
                state_raw = torch.load(state_path, map_location="cpu")
 | 
			
		||||
                if not os.path.isfile(state_path):
 | 
			
		||||
                    return HTTPException(
 | 
			
		||||
                        status.HTTP_400_BAD_REQUEST, "state file not found"
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                try:
 | 
			
		||||
                    state_raw = torch.load(state_path, map_location="cpu")
 | 
			
		||||
                except Exception as e:
 | 
			
		||||
                    print(e)
 | 
			
		||||
                    return HTTPException(
 | 
			
		||||
                        status.HTTP_400_BAD_REQUEST, "state file failed to load"
 | 
			
		||||
                    )
 | 
			
		||||
                state_raw_shape = next(iter(state_raw.values())).shape
 | 
			
		||||
 | 
			
		||||
                args = model.model.args
 | 
			
		||||
@ -736,7 +749,7 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
 | 
			
		||||
                ):
 | 
			
		||||
                    if model.state_path:
 | 
			
		||||
                        pass
 | 
			
		||||
                    else:
 | 
			
		||||
                    elif print_log:
 | 
			
		||||
                        print("state failed to load")
 | 
			
		||||
                    return HTTPException(
 | 
			
		||||
                        status.HTTP_400_BAD_REQUEST, "state shape mismatch"
 | 
			
		||||
@ -765,23 +778,27 @@ def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
 | 
			
		||||
 | 
			
		||||
                state_cache.force_reset_state()
 | 
			
		||||
                model.state_path = state_path
 | 
			
		||||
                print("state loaded")
 | 
			
		||||
                if print_log:
 | 
			
		||||
                    print("state loaded")
 | 
			
		||||
            else:
 | 
			
		||||
                if model.state_path:
 | 
			
		||||
                    pass
 | 
			
		||||
                else:
 | 
			
		||||
                elif print_log:
 | 
			
		||||
                    print("state failed to load")
 | 
			
		||||
                return HTTPException(
 | 
			
		||||
                    status.HTTP_400_BAD_REQUEST,
 | 
			
		||||
                    "file format of the model or state model not supported",
 | 
			
		||||
                )
 | 
			
		||||
        else:
 | 
			
		||||
            state_cache.force_reset_state()
 | 
			
		||||
            model.state_path = ""
 | 
			
		||||
            model.state_tuned = None  # TODO cached
 | 
			
		||||
            print("state unloaded")
 | 
			
		||||
            if state_path == "" and model.state_path != "":
 | 
			
		||||
                state_cache.force_reset_state()
 | 
			
		||||
                model.state_path = ""
 | 
			
		||||
                model.state_tuned = None  # TODO cached
 | 
			
		||||
                if print_log:
 | 
			
		||||
                    print("state unloaded")
 | 
			
		||||
    else:
 | 
			
		||||
        print("state not loaded")
 | 
			
		||||
        if print_log:
 | 
			
		||||
            print("state not loaded")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
 | 
			
		||||
@ -805,7 +822,7 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
 | 
			
		||||
    if body.global_penalty is not None:
 | 
			
		||||
        model.global_penalty = body.global_penalty
 | 
			
		||||
    if body.state is not None:
 | 
			
		||||
        load_rwkv_state(model, body.state)
 | 
			
		||||
        load_rwkv_state(model, body.state, False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user