fix load_state_dict crash
parent d8c70453ec
commit 5ee5fa7e6e

finetune/lora/train.py
@@ -390,6 +390,7 @@ if __name__ == "__main__":
     rank_zero_info(f"########## Loading {args.load_model}... ##########")
     try:
         load_dict = torch.load(args.load_model, map_location="cpu")
+        model.load_state_dict(load_dict, strict=(not args.lora))
     except:
         rank_zero_info(f"Bad checkpoint {args.load_model}")
         if args.my_pile_stage >= 2:  # try again using another checkpoint
@@ -401,14 +402,16 @@ if __name__ == "__main__":
             args.epoch_begin = max_p + 1
             rank_zero_info(f"Trying {args.load_model}")
             load_dict = torch.load(args.load_model, map_location="cpu")
+            model.load_state_dict(load_dict, strict=(not args.lora))

     if args.load_partial == 1:
         load_keys = load_dict.keys()
         for k in model.state_dict():
             if k not in load_keys:
                 load_dict[k] = model.state_dict()[k]
+        model.load_state_dict(load_dict, strict=(not args.lora))
     # If using LoRA, the LoRA keys might be missing in the original model
-    model.load_state_dict(load_dict, strict=(not args.lora))
+    # model.load_state_dict(load_dict, strict=(not args.lora))
     if os.path.isfile(args.lora_load):
         model.load_state_dict(
             torch.load(args.lora_load, map_location="cpu"), strict=False
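For reference, below is a minimal, self-contained sketch of the loading pattern the patched code follows: load the checkpoint inside a try block so a bad file is caught, backfill keys missing from a partial checkpoint with the freshly initialized weights, pass strict=False whenever LoRA adds parameters the base checkpoint lacks, and finally overlay a separate LoRA checkpoint if one exists. The model class, function name, and file paths here are illustrative stand-ins, not the actual objects in train.py.

import os

import torch
import torch.nn as nn

# Illustrative stand-in; the real script builds its RWKV model from args.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

def load_checkpoint(model, ckpt_path, lora=False, load_partial=False, lora_path=""):
    # Load inside try so a bad or missing checkpoint does not kill the run.
    try:
        load_dict = torch.load(ckpt_path, map_location="cpu")
        # strict=False when LoRA adds params that the base checkpoint lacks.
        model.load_state_dict(load_dict, strict=(not lora))
    except Exception:
        print(f"Bad checkpoint {ckpt_path}")
        return model
    if load_partial:
        # Keys absent from the checkpoint keep the model's initialized values.
        current = model.state_dict()
        for k in current:
            if k not in load_dict:
                load_dict[k] = current[k]
        model.load_state_dict(load_dict, strict=(not lora))
    if os.path.isfile(lora_path):
        # Overlay LoRA weights; strict=False since the file holds only LoRA keys.
        model.load_state_dict(torch.load(lora_path, map_location="cpu"), strict=False)
    return model

# A nonexistent path simply takes the bad-checkpoint branch and returns the model.
model = load_checkpoint(TinyModel(), "rwkv-init.pth", lora=True, load_partial=True)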