fix load_state_dict crash
This commit is contained in:
		
							parent
							
								
									d8c70453ec
								
							
						
					
					
						commit
						5ee5fa7e6e
					
				
							
								
								
									
										5
									
								
								finetune/lora/train.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								finetune/lora/train.py
									
									
									
									
										vendored
									
									
								
							@ -390,6 +390,7 @@ if __name__ == "__main__":
 | 
			
		||||
    rank_zero_info(f"########## Loading {args.load_model}... ##########")
 | 
			
		||||
    try:
 | 
			
		||||
        load_dict = torch.load(args.load_model, map_location="cpu")
 | 
			
		||||
        model.load_state_dict(load_dict, strict=(not args.lora))
 | 
			
		||||
    except:
 | 
			
		||||
        rank_zero_info(f"Bad checkpoint {args.load_model}")
 | 
			
		||||
        if args.my_pile_stage >= 2:  # try again using another checkpoint
 | 
			
		||||
@ -401,14 +402,16 @@ if __name__ == "__main__":
 | 
			
		||||
            args.epoch_begin = max_p + 1
 | 
			
		||||
            rank_zero_info(f"Trying {args.load_model}")
 | 
			
		||||
            load_dict = torch.load(args.load_model, map_location="cpu")
 | 
			
		||||
            model.load_state_dict(load_dict, strict=(not args.lora))
 | 
			
		||||
 | 
			
		||||
    if args.load_partial == 1:
 | 
			
		||||
        load_keys = load_dict.keys()
 | 
			
		||||
        for k in model.state_dict():
 | 
			
		||||
            if k not in load_keys:
 | 
			
		||||
                load_dict[k] = model.state_dict()[k]
 | 
			
		||||
        model.load_state_dict(load_dict, strict=(not args.lora))
 | 
			
		||||
    # If using LoRA, the LoRA keys might be missing in the original model
 | 
			
		||||
    model.load_state_dict(load_dict, strict=(not args.lora))
 | 
			
		||||
    # model.load_state_dict(load_dict, strict=(not args.lora))
 | 
			
		||||
    if os.path.isfile(args.lora_load):
 | 
			
		||||
        model.load_state_dict(
 | 
			
		||||
            torch.load(args.lora_load, map_location="cpu"), strict=False
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user