From d8c70453ecc5a65ded16c8e1db9f81376e98942d Mon Sep 17 00:00:00 2001 From: josc146 Date: Sun, 9 Jul 2023 12:32:50 +0800 Subject: [PATCH] format --- finetune/lora/train.py | 198 +++++++++++++++++++++++++++++------------ 1 file changed, 143 insertions(+), 55 deletions(-) diff --git a/finetune/lora/train.py b/finetune/lora/train.py index 0548dab..8582da8 100644 --- a/finetune/lora/train.py +++ b/finetune/lora/train.py @@ -50,52 +50,84 @@ if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--load_model", default="", type=str) # full path, with .pth - parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb + parser.add_argument( + "--wandb", default="", type=str + ) # wandb project name. if "" then don't use wandb parser.add_argument("--proj_dir", default="out", type=str) parser.add_argument("--random_seed", default="-1", type=int) parser.add_argument("--data_file", default="", type=str) parser.add_argument("--data_type", default="utf-8", type=str) - parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data) + parser.add_argument( + "--vocab_size", default=0, type=int + ) # vocab_size = 0 means auto (for char-level LM and .txt data) parser.add_argument("--ctx_len", default=1024, type=int) - parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps - parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final - parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x - parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs" + parser.add_argument( + "--epoch_steps", default=1000, type=int + ) # a mini "epoch" has [epoch_steps] steps + parser.add_argument( + "--epoch_count", default=500, type=int + ) # train for this many "epochs". will continue afterwards with lr = lr_final + parser.add_argument( + "--epoch_begin", default=0, type=int + ) # if you load a model trained for x "epochs", set epoch_begin = x + parser.add_argument( + "--epoch_save", default=5, type=int + ) # save the model every [epoch_save] "epochs" - parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU) + parser.add_argument( + "--micro_bsz", default=12, type=int + ) # micro batch size (batch size per GPU) parser.add_argument("--n_layer", default=6, type=int) parser.add_argument("--n_embd", default=512, type=int) parser.add_argument("--dim_att", default=0, type=int) parser.add_argument("--dim_ffn", default=0, type=int) - parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better) + parser.add_argument( + "--pre_ffn", default=0, type=int + ) # replace first att layer by ffn (sometimes better) parser.add_argument("--head_qk", default=0, type=int) # my headQK trick parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim - parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer + parser.add_argument( + "--tiny_att_layer", default=-999, type=int + ) # tiny attention @ which layer - parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 + parser.add_argument( + "--lr_init", default=6e-4, type=float + ) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 parser.add_argument("--lr_final", default=1e-5, type=float) - parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model + parser.add_argument( + "--warmup_steps", default=0, type=int + ) # try 50 if you load a model parser.add_argument("--beta1", default=0.9, type=float) - parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence + parser.add_argument( + "--beta2", default=0.99, type=float + ) # use 0.999 when your model is close to convergence parser.add_argument("--adam_eps", default=1e-8, type=float) - parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower + parser.add_argument( + "--grad_cp", default=0, type=int + ) # gradient checkpt: saves VRAM, but slower parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode - parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift + parser.add_argument( + "--my_pile_shift", default=-1, type=int + ) # my special pile mode - text shift parser.add_argument("--my_pile_edecay", default=0, type=int) - parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) - parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough + parser.add_argument( + "--layerwise_lr", default=1, type=int + ) # layerwise lr for faster convergence (but slower it/s) + parser.add_argument( + "--ds_bucket_mb", default=200, type=int + ) # deepspeed bucket size in MB. 200 seems enough # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful) parser.add_argument("--my_img_version", default=0, type=str) parser.add_argument("--my_img_size", default=0, type=int) parser.add_argument("--my_img_bit", default=0, type=int) - parser.add_argument("--my_img_clip", default='x', type=str) + parser.add_argument("--my_img_clip", default="x", type=str) parser.add_argument("--my_img_clip_scale", default=1, type=float) parser.add_argument("--my_img_l1_scale", default=0, type=float) - parser.add_argument("--my_img_encoder", default='x', type=str) + parser.add_argument("--my_img_encoder", default="x", type=str) # parser.add_argument("--my_img_noise_scale", default=0, type=float) parser.add_argument("--my_sample_len", default=0, type=int) parser.add_argument("--my_ffn_shift", default=1, type=int) @@ -104,7 +136,7 @@ if __name__ == "__main__": parser.add_argument("--load_partial", default=0, type=int) parser.add_argument("--magic_prime", default=0, type=int) parser.add_argument("--my_qa_mask", default=0, type=int) - parser.add_argument("--my_testing", default='', type=str) + parser.add_argument("--my_testing", default="", type=str) parser.add_argument("--lora", action="store_true") parser.add_argument("--lora_load", default="", type=str) @@ -122,18 +154,26 @@ if __name__ == "__main__": import numpy as np import torch from torch.utils.data import DataLoader + if "deepspeed" in args.strategy: import deepspeed import pytorch_lightning as pl from pytorch_lightning import seed_everything if args.random_seed >= 0: - print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3) + print( + f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" + * 3 + ) seed_everything(args.random_seed) np.set_printoptions(precision=4, suppress=True, linewidth=200) - warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") - warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*") + warnings.filterwarnings( + "ignore", ".*Consider increasing the value of the `num_workers` argument*" + ) + warnings.filterwarnings( + "ignore", ".*The progress bar already tracks a metric with the*" + ) # os.environ["WDS_SHOW_SEED"] = "1" args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") @@ -158,7 +198,9 @@ if __name__ == "__main__": args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}" args.proj_dir = f"{args.proj_dir}-{args.run_name}" else: - args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" + args.run_name = ( + f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" + ) if not os.path.exists(args.proj_dir): os.makedirs(args.proj_dir) @@ -240,24 +282,40 @@ if __name__ == "__main__": ) rank_zero_info(str(vars(args)) + "\n") - assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"] + assert args.data_type in [ + "utf-8", + "utf-16le", + "numpy", + "binidx", + "dummy", + "wds_img", + "uint16", + ] if args.lr_final == 0 or args.lr_init == 0: - rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n") + rank_zero_info( + "\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n" + ) assert args.precision in ["fp32", "tf32", "fp16", "bf16"] os.environ["RWKV_FLOAT_MODE"] = args.precision if args.precision == "fp32": for i in range(10): - rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n") + rank_zero_info( + "\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n" + ) if args.precision == "fp16": - rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n") + rank_zero_info( + "\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n" + ) os.environ["RWKV_JIT_ON"] = "1" if "deepspeed_stage_3" in args.strategy: os.environ["RWKV_JIT_ON"] = "0" if args.lora and args.grad_cp == 1: - print('!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it') + print( + "!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it" + ) os.environ["RWKV_JIT_ON"] = "0" torch.backends.cudnn.benchmark = True @@ -284,20 +342,22 @@ if __name__ == "__main__": train_data = MyDataset(args) args.vocab_size = train_data.vocab_size - if args.data_type == 'wds_img': + if args.data_type == "wds_img": from src.model_img import RWKV_IMG + assert args.lora, "LoRA not yet supported for RWKV_IMG" model = RWKV_IMG(args) else: from src.model import RWKV, LORA_CONFIG, LoraLinear + if args.lora: assert args.lora_r > 0, "LoRA should have its `r` > 0" LORA_CONFIG["r"] = args.lora_r LORA_CONFIG["alpha"] = args.lora_alpha LORA_CONFIG["dropout"] = args.lora_dropout - LORA_CONFIG["parts"] = set(str(args.lora_parts).split(',')) - enable_time_finetune = 'time' in LORA_CONFIG["parts"] - enable_ln_finetune = 'ln' in LORA_CONFIG["parts"] + LORA_CONFIG["parts"] = set(str(args.lora_parts).split(",")) + enable_time_finetune = "time" in LORA_CONFIG["parts"] + enable_ln_finetune = "ln" in LORA_CONFIG["parts"] model = RWKV(args) # only train lora parameters if args.lora: @@ -305,20 +365,24 @@ if __name__ == "__main__": for name, module in model.named_modules(): # have to check param name since it may have been wrapped by torchscript if any(n.startswith("lora_") for n, _ in module.named_parameters()): - print(f' LoRA training module {name}') + print(f" LoRA training module {name}") for pname, param in module.named_parameters(): - param.requires_grad = 'lora_' in pname - elif enable_ln_finetune and '.ln' in name: - print(f' LoRA additionally training module {name}') + param.requires_grad = "lora_" in pname + elif enable_ln_finetune and ".ln" in name: + print(f" LoRA additionally training module {name}") for param in module.parameters(): param.requires_grad = True - elif enable_time_finetune and any(n.startswith("time") for n, _ in module.named_parameters()): + elif enable_time_finetune and any( + n.startswith("time") for n, _ in module.named_parameters() + ): for pname, param in module.named_parameters(): if pname.startswith("time"): - print(f' LoRA additionally training parameter {pname}') + print(f" LoRA additionally training parameter {pname}") param.requires_grad = True - if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights? + if ( + len(args.load_model) == 0 or args.my_pile_stage == 1 + ): # shall we build the initial weights? init_weight_name = f"{args.proj_dir}/rwkv-init.pth" generate_init_weight(model, init_weight_name) # save initial weights args.load_model = init_weight_name @@ -346,27 +410,39 @@ if __name__ == "__main__": # If using LoRA, the LoRA keys might be missing in the original model model.load_state_dict(load_dict, strict=(not args.lora)) if os.path.isfile(args.lora_load): - model.load_state_dict(torch.load(args.lora_load, map_location="cpu"), - strict=False) + model.load_state_dict( + torch.load(args.lora_load, map_location="cpu"), strict=False + ) trainer: Trainer = Trainer.from_argparse_args( args, callbacks=[train_callback(args)], ) - - if (args.lr_init > 1e-4 or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8): - if 'I_KNOW_WHAT_IM_DOING' in os.environ: + + if ( + args.lr_init > 1e-4 + or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8 + ): + if "I_KNOW_WHAT_IM_DOING" in os.environ: if trainer.global_rank == 0: - print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') - print(f' WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)') - print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print( + f" WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)" + ) + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") else: if trainer.global_rank == 0: - print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') - print(f' ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)') - print(f' Unless you are sure this is what you want, adjust them accordingly') - print(f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")') - print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print( + f" ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)" + ) + print( + f" Unless you are sure this is what you want, adjust them accordingly" + ) + print( + f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")' + ) + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") exit(0) if trainer.global_rank == 0: @@ -379,10 +455,22 @@ if __name__ == "__main__": print(f"{str(shape[0]).ljust(5)} {n}") if "deepspeed" in args.strategy: - trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 - trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = ( + args.ds_bucket_mb * 1000 * 1000 + ) + trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = ( + args.ds_bucket_mb * 1000 * 1000 + ) # must set shuffle=False, persistent_workers=False (because worker is in another thread) - data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) + data_loader = DataLoader( + train_data, + shuffle=False, + pin_memory=True, + batch_size=args.micro_bsz, + num_workers=1, + persistent_workers=False, + drop_last=True, + ) trainer.fit(model, data_loader)