fix load_state_dict crash
commit 5ee5fa7e6e (parent d8c70453ec)
finetune/lora/train.py (vendored): 5 changed lines
@@ -390,6 +390,7 @@ if __name__ == "__main__":
     rank_zero_info(f"########## Loading {args.load_model}... ##########")
     try:
         load_dict = torch.load(args.load_model, map_location="cpu")
+        model.load_state_dict(load_dict, strict=(not args.lora))
     except:
         rank_zero_info(f"Bad checkpoint {args.load_model}")
         if args.my_pile_stage >= 2:  # try again using another checkpoint
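Note (not part of the commit): the load is guarded with strict=(not args.lora) because a LoRA-augmented model carries extra lora_* parameters that a base RWKV checkpoint does not contain, so a strict load_state_dict raises a missing-keys RuntimeError; doing the load inside the try also lets the existing bad-checkpoint fallback run instead of crashing. A minimal sketch of that behavior, with a hypothetical LoraLinear module standing in for the real model:

import torch
import torch.nn as nn

# Hypothetical stand-in for a LoRA-wrapped layer: it owns extra
# lora_A / lora_B parameters that a pre-LoRA checkpoint never saved.
class LoraLinear(nn.Module):
    def __init__(self, n_in, n_out, r=4):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(n_out, n_in))
        self.lora_A = nn.Parameter(torch.zeros(r, n_in))
        self.lora_B = nn.Parameter(torch.zeros(n_out, r))

model = LoraLinear(8, 8)
base_ckpt = {"weight": torch.zeros(8, 8)}  # checkpoint saved before LoRA was added

# strict=True (the non-LoRA path) rejects the checkpoint: lora_A / lora_B are missing.
try:
    model.load_state_dict(base_ckpt, strict=True)
except RuntimeError as e:
    print("strict load failed:", e)

# strict=False (what strict=(not args.lora) selects when --lora is on) loads the
# shared weights and leaves the LoRA parameters at their initial values.
missing, unexpected = model.load_state_dict(base_ckpt, strict=False)
print(missing)     # ['lora_A', 'lora_B']
print(unexpected)  # []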
@@ -401,14 +402,16 @@ if __name__ == "__main__":
             args.epoch_begin = max_p + 1
             rank_zero_info(f"Trying {args.load_model}")
             load_dict = torch.load(args.load_model, map_location="cpu")
+            model.load_state_dict(load_dict, strict=(not args.lora))
 
     if args.load_partial == 1:
         load_keys = load_dict.keys()
         for k in model.state_dict():
             if k not in load_keys:
                 load_dict[k] = model.state_dict()[k]
+        model.load_state_dict(load_dict, strict=(not args.lora))
     # If using LoRA, the LoRA keys might be missing in the original model
-    model.load_state_dict(load_dict, strict=(not args.lora))
+    # model.load_state_dict(load_dict, strict=(not args.lora))
     if os.path.isfile(args.lora_load):
         model.load_state_dict(
             torch.load(args.lora_load, map_location="cpu"), strict=False
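Note (not part of the commit): the load_partial branch copies any key the checkpoint lacks from the model's own state_dict before loading, and the adapter weights from args.lora_load are then applied on top with strict=False. A rough, self-contained sketch of that merge pattern, using plain nn.Linear layers and made-up tensors rather than the repo's model:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# A partial checkpoint: only the first layer was saved (assumed example data).
load_dict = {"0.weight": torch.zeros(8, 8), "0.bias": torch.zeros(8)}

# Same merge pattern as the load_partial branch: every key the checkpoint
# lacks is filled in from the model's current state_dict, so the subsequent
# load sees a complete dict and a strict load would also succeed.
load_keys = load_dict.keys()
for k in model.state_dict():
    if k not in load_keys:
        load_dict[k] = model.state_dict()[k]

model.load_state_dict(load_dict, strict=True)  # no missing keys left

# LoRA adapter weights would then be layered on top with strict=False,
# which ignores all the base-model keys the small adapter file omits
# (hypothetical path, shown for illustration only):
# model.load_state_dict(torch.load("lora.pth", map_location="cpu"), strict=False)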