fix load_state_dict crash

josc146 2023-07-09 12:33:29 +08:00
parent d8c70453ec
commit 5ee5fa7e6e

@@ -390,6 +390,7 @@ if __name__ == "__main__":
     rank_zero_info(f"########## Loading {args.load_model}... ##########")
     try:
         load_dict = torch.load(args.load_model, map_location="cpu")
+        model.load_state_dict(load_dict, strict=(not args.lora))
     except:
         rank_zero_info(f"Bad checkpoint {args.load_model}")
         if args.my_pile_stage >= 2:  # try again using another checkpoint
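
The first hunk moves model.load_state_dict inside the try block, so a checkpoint with mismatched keys is caught by the except branch instead of crashing before the fallback logic can run. A minimal sketch of the failure mode being guarded against, with toy modules that are illustrative rather than taken from this repo:

import torch
import torch.nn as nn

# Two toy modules whose state-dict keys deliberately disagree.
src = nn.Linear(4, 4)                 # keys: "weight", "bias"
dst = nn.Sequential(nn.Linear(4, 4))  # keys: "0.weight", "0.bias"

try:
    # strict=True (the default) raises RuntimeError on missing or
    # unexpected keys -- the crash this commit catches.
    dst.load_state_dict(src.state_dict())
except RuntimeError as e:
    print(f"Bad checkpoint: {e}")  # the trainer falls back to another checkpoint here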
@@ -401,14 +402,16 @@ if __name__ == "__main__":
             args.epoch_begin = max_p + 1
             rank_zero_info(f"Trying {args.load_model}")
             load_dict = torch.load(args.load_model, map_location="cpu")
+            model.load_state_dict(load_dict, strict=(not args.lora))
     if args.load_partial == 1:
         load_keys = load_dict.keys()
         for k in model.state_dict():
             if k not in load_keys:
                 load_dict[k] = model.state_dict()[k]
+        model.load_state_dict(load_dict, strict=(not args.lora))
     # If using LoRA, the LoRA keys might be missing in the original model
-    model.load_state_dict(load_dict, strict=(not args.lora))
+    # model.load_state_dict(load_dict, strict=(not args.lora))
     if os.path.isfile(args.lora_load):
         model.load_state_dict(
             torch.load(args.lora_load, map_location="cpu"), strict=False
         )
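
For context on the second hunk: when args.load_partial is set, keys absent from the checkpoint are backfilled from the model's current weights so the subsequent (possibly strict) load succeeds, and a LoRA checkpoint containing only adapter tensors is applied with strict=False so the base weights it lacks are ignored. A standalone sketch of that pattern, assuming a hypothetical helper name and checkpoint path not taken from the repo:

import os
import torch
import torch.nn as nn

def load_partial(model: nn.Module, load_dict: dict) -> None:
    # Backfill keys missing from the checkpoint with the model's current
    # tensors, mirroring the args.load_partial branch in the diff.
    for k, v in model.state_dict().items():
        if k not in load_dict:
            load_dict[k] = v
    model.load_state_dict(load_dict)  # a strict load now succeeds

model = nn.Linear(4, 4)
load_partial(model, {})  # degenerate case: every key is backfilled

# LoRA overlay ("lora.pth" is a placeholder path): the adapter file holds
# only LoRA tensors, so strict=False skips the base weights it lacks.
if os.path.isfile("lora.pth"):
    model.load_state_dict(torch.load("lora.pth", map_location="cpu"), strict=False)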