From fe0860dbf0e27e3614eec6b835fb17983d012f41 Mon Sep 17 00:00:00 2001
From: josc146
Date: Thu, 24 Aug 2023 22:49:57 +0800
Subject: [PATCH] fix lora finetune max_epochs (#170)

---
 finetune/lora/train.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/finetune/lora/train.py b/finetune/lora/train.py
index 99c49e9..a03322f 100644
--- a/finetune/lora/train.py
+++ b/finetune/lora/train.py
@@ -184,7 +184,7 @@ if __name__ == "__main__":
     args.num_sanity_val_steps = 0
     args.check_val_every_n_epoch = int(1e20)
     args.log_every_n_steps = int(1e20)
-    args.max_epochs = -1  # continue forever
+    args.max_epochs = args.epoch_count  # continue forever
     args.betas = (args.beta1, args.beta2)
     args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
     os.environ["RWKV_T_MAX"] = str(args.ctx_len)
@@ -373,7 +373,7 @@ if __name__ == "__main__":
                 for param in module.parameters():
                     param.requires_grad = True
             elif enable_time_finetune and any(
-                n.startswith("time") for n, _ in module.named_parameters()
+                n.startswith("time") for n, _ in module.named_parameters()
             ):
                 for pname, param in module.named_parameters():
                     if pname.startswith("time"):
@@ -381,7 +381,7 @@ if __name__ == "__main__":
                         param.requires_grad = True
 
     if (
-        len(args.load_model) == 0 or args.my_pile_stage == 1
+        len(args.load_model) == 0 or args.my_pile_stage == 1
     ):  # shall we build the initial weights?
         init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
         generate_init_weight(model, init_weight_name)  # save initial weights
@@ -423,8 +423,8 @@ if __name__ == "__main__":
     )
 
     if (
-        args.lr_init > 1e-4
-        or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
+        args.lr_init > 1e-4
+        or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
     ):
         if "I_KNOW_WHAT_IM_DOING" in os.environ:
             if trainer.global_rank == 0:
@@ -459,10 +459,10 @@ if __name__ == "__main__":
 
     if "deepspeed" in args.strategy:
         trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
-            args.ds_bucket_mb * 1000 * 1000
+            args.ds_bucket_mb * 1000 * 1000
         )
         trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
-            args.ds_bucket_mb * 1000 * 1000
+            args.ds_bucket_mb * 1000 * 1000
         )
 
     # must set shuffle=False, persistent_workers=False (because worker is in another thread)
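
In PyTorch Lightning, max_epochs=-1 tells the Trainer to train indefinitely, so with the old value the LoRA finetune never stopped at the requested epoch count; pointing args.max_epochs at args.epoch_count lets the run terminate. A minimal sketch of the effect, assuming the args namespace is forwarded to a Lightning Trainer roughly like this (illustrative wiring only, not the script's actual code):

import pytorch_lightning as pl

def build_trainer(args):
    # After the patch, args.max_epochs equals args.epoch_count, so the
    # Trainer stops after that many epochs instead of running forever (-1).
    return pl.Trainer(
        max_epochs=args.max_epochs,
        devices=args.devices,
        num_nodes=args.num_nodes,
    )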