diff --git a/finetune/lora/train.py b/finetune/lora/train.py
index 8582da8..99c49e9 100644
--- a/finetune/lora/train.py
+++ b/finetune/lora/train.py
@@ -390,6 +390,7 @@ if __name__ == "__main__":
     rank_zero_info(f"########## Loading {args.load_model}... ##########")
     try:
         load_dict = torch.load(args.load_model, map_location="cpu")
+        model.load_state_dict(load_dict, strict=(not args.lora))
     except:
         rank_zero_info(f"Bad checkpoint {args.load_model}")
         if args.my_pile_stage >= 2:  # try again using another checkpoint
@@ -401,14 +402,16 @@ if __name__ == "__main__":
             args.epoch_begin = max_p + 1
             rank_zero_info(f"Trying {args.load_model}")
             load_dict = torch.load(args.load_model, map_location="cpu")
+            model.load_state_dict(load_dict, strict=(not args.lora))
 
     if args.load_partial == 1:
         load_keys = load_dict.keys()
         for k in model.state_dict():
             if k not in load_keys:
                 load_dict[k] = model.state_dict()[k]
+        model.load_state_dict(load_dict, strict=(not args.lora))
     # If using LoRA, the LoRA keys might be missing in the original model
-    model.load_state_dict(load_dict, strict=(not args.lora))
+    # model.load_state_dict(load_dict, strict=(not args.lora))
     if os.path.isfile(args.lora_load):
         model.load_state_dict(
             torch.load(args.lora_load, map_location="cpu"), strict=False
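The net effect of the patch is that the base checkpoint is applied to the model as soon as it has been successfully read (non-strictly when LoRA is enabled, since the base checkpoint contains no `lora_*` keys), and the original unconditional `load_state_dict` call at the end is commented out; a LoRA checkpoint, if present, is still overlaid afterwards with `strict=False`. Below is a minimal, self-contained sketch of that loading order, not the patched `train.py` itself; the function name and parameters (`load_base_and_lora`, `base_ckpt`, `lora_ckpt`, `use_lora`) are illustrative and do not exist in the repository.

```python
import os
import torch


def load_base_and_lora(model: torch.nn.Module,
                       base_ckpt: str,
                       lora_ckpt: str = "",
                       use_lora: bool = True) -> torch.nn.Module:
    """Sketch (under the assumptions above) of the loading order the patch sets up."""
    # Base weights: when LoRA is enabled, the base checkpoint lacks the lora_*
    # keys, so load non-strictly and leave those parameters at their init values.
    load_dict = torch.load(base_ckpt, map_location="cpu")
    model.load_state_dict(load_dict, strict=(not use_lora))

    # Previously trained LoRA weights (if any) are overlaid afterwards, also
    # non-strictly, so only the lora_* tensors in the checkpoint are applied.
    if lora_ckpt and os.path.isfile(lora_ckpt):
        model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)
    return model
```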