fix lora finetune max_epochs (#170)
This commit is contained in:
parent
02d5d641d1
commit
fe0860dbf0
14
finetune/lora/train.py
vendored
14
finetune/lora/train.py
vendored
@ -184,7 +184,7 @@ if __name__ == "__main__":
|
||||
args.num_sanity_val_steps = 0
|
||||
args.check_val_every_n_epoch = int(1e20)
|
||||
args.log_every_n_steps = int(1e20)
|
||||
args.max_epochs = -1 # continue forever
|
||||
args.max_epochs = args.epoch_count # continue forever
|
||||
args.betas = (args.beta1, args.beta2)
|
||||
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
|
||||
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
|
||||
@ -373,7 +373,7 @@ if __name__ == "__main__":
|
||||
for param in module.parameters():
|
||||
param.requires_grad = True
|
||||
elif enable_time_finetune and any(
|
||||
n.startswith("time") for n, _ in module.named_parameters()
|
||||
n.startswith("time") for n, _ in module.named_parameters()
|
||||
):
|
||||
for pname, param in module.named_parameters():
|
||||
if pname.startswith("time"):
|
||||
@ -381,7 +381,7 @@ if __name__ == "__main__":
|
||||
param.requires_grad = True
|
||||
|
||||
if (
|
||||
len(args.load_model) == 0 or args.my_pile_stage == 1
|
||||
len(args.load_model) == 0 or args.my_pile_stage == 1
|
||||
): # shall we build the initial weights?
|
||||
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
|
||||
generate_init_weight(model, init_weight_name) # save initial weights
|
||||
@ -423,8 +423,8 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
if (
|
||||
args.lr_init > 1e-4
|
||||
or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
|
||||
args.lr_init > 1e-4
|
||||
or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
|
||||
):
|
||||
if "I_KNOW_WHAT_IM_DOING" in os.environ:
|
||||
if trainer.global_rank == 0:
|
||||
@ -459,10 +459,10 @@ if __name__ == "__main__":
|
||||
|
||||
if "deepspeed" in args.strategy:
|
||||
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
|
||||
args.ds_bucket_mb * 1000 * 1000
|
||||
args.ds_bucket_mb * 1000 * 1000
|
||||
)
|
||||
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
|
||||
args.ds_bucket_mb * 1000 * 1000
|
||||
args.ds_bucket_mb * 1000 * 1000
|
||||
)
|
||||
|
||||
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
||||
|
Loading…
Reference in New Issue
Block a user