fix load_state_dict crash
This commit is contained in:
		
							parent
							
								
									d8c70453ec
								
							
						
					
					
						commit
						5ee5fa7e6e
					
				
							
								
								
									
										5
									
								
								finetune/lora/train.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								finetune/lora/train.py
									
									
									
									
										vendored
									
									
								
							@ -390,6 +390,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    rank_zero_info(f"########## Loading {args.load_model}... ##########")
 | 
					    rank_zero_info(f"########## Loading {args.load_model}... ##########")
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        load_dict = torch.load(args.load_model, map_location="cpu")
 | 
					        load_dict = torch.load(args.load_model, map_location="cpu")
 | 
				
			||||||
 | 
					        model.load_state_dict(load_dict, strict=(not args.lora))
 | 
				
			||||||
    except:
 | 
					    except:
 | 
				
			||||||
        rank_zero_info(f"Bad checkpoint {args.load_model}")
 | 
					        rank_zero_info(f"Bad checkpoint {args.load_model}")
 | 
				
			||||||
        if args.my_pile_stage >= 2:  # try again using another checkpoint
 | 
					        if args.my_pile_stage >= 2:  # try again using another checkpoint
 | 
				
			||||||
@ -401,14 +402,16 @@ if __name__ == "__main__":
 | 
				
			|||||||
            args.epoch_begin = max_p + 1
 | 
					            args.epoch_begin = max_p + 1
 | 
				
			||||||
            rank_zero_info(f"Trying {args.load_model}")
 | 
					            rank_zero_info(f"Trying {args.load_model}")
 | 
				
			||||||
            load_dict = torch.load(args.load_model, map_location="cpu")
 | 
					            load_dict = torch.load(args.load_model, map_location="cpu")
 | 
				
			||||||
 | 
					            model.load_state_dict(load_dict, strict=(not args.lora))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if args.load_partial == 1:
 | 
					    if args.load_partial == 1:
 | 
				
			||||||
        load_keys = load_dict.keys()
 | 
					        load_keys = load_dict.keys()
 | 
				
			||||||
        for k in model.state_dict():
 | 
					        for k in model.state_dict():
 | 
				
			||||||
            if k not in load_keys:
 | 
					            if k not in load_keys:
 | 
				
			||||||
                load_dict[k] = model.state_dict()[k]
 | 
					                load_dict[k] = model.state_dict()[k]
 | 
				
			||||||
    # If using LoRA, the LoRA keys might be missing in the original model
 | 
					 | 
				
			||||||
        model.load_state_dict(load_dict, strict=(not args.lora))
 | 
					        model.load_state_dict(load_dict, strict=(not args.lora))
 | 
				
			||||||
 | 
					    # If using LoRA, the LoRA keys might be missing in the original model
 | 
				
			||||||
 | 
					    # model.load_state_dict(load_dict, strict=(not args.lora))
 | 
				
			||||||
    if os.path.isfile(args.lora_load):
 | 
					    if os.path.isfile(args.lora_load):
 | 
				
			||||||
        model.load_state_dict(
 | 
					        model.load_state_dict(
 | 
				
			||||||
            torch.load(args.lora_load, map_location="cpu"), strict=False
 | 
					            torch.load(args.lora_load, map_location="cpu"), strict=False
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user