format

This commit is contained in:
    commit  d8c70453ec
    parent  e930eb5967

Changed file: finetune/lora/train.py (198 changed lines, vendored)
@@ -50,52 +50,84 @@ if __name__ == "__main__":
     parser = ArgumentParser()
 
     parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
-    parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
+    parser.add_argument(
+        "--wandb", default="", type=str
+    )  # wandb project name. if "" then don't use wandb
     parser.add_argument("--proj_dir", default="out", type=str)
     parser.add_argument("--random_seed", default="-1", type=int)
 
     parser.add_argument("--data_file", default="", type=str)
     parser.add_argument("--data_type", default="utf-8", type=str)
-    parser.add_argument("--vocab_size", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)
+    parser.add_argument(
+        "--vocab_size", default=0, type=int
+    )  # vocab_size = 0 means auto (for char-level LM and .txt data)
 
     parser.add_argument("--ctx_len", default=1024, type=int)
-    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has [epoch_steps] steps
-    parser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_final
-    parser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = x
-    parser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"
+    parser.add_argument(
+        "--epoch_steps", default=1000, type=int
+    )  # a mini "epoch" has [epoch_steps] steps
+    parser.add_argument(
+        "--epoch_count", default=500, type=int
+    )  # train for this many "epochs". will continue afterwards with lr = lr_final
+    parser.add_argument(
+        "--epoch_begin", default=0, type=int
+    )  # if you load a model trained for x "epochs", set epoch_begin = x
+    parser.add_argument(
+        "--epoch_save", default=5, type=int
+    )  # save the model every [epoch_save] "epochs"
 
-    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
+    parser.add_argument(
+        "--micro_bsz", default=12, type=int
+    )  # micro batch size (batch size per GPU)
     parser.add_argument("--n_layer", default=6, type=int)
     parser.add_argument("--n_embd", default=512, type=int)
     parser.add_argument("--dim_att", default=0, type=int)
     parser.add_argument("--dim_ffn", default=0, type=int)
-    parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
+    parser.add_argument(
+        "--pre_ffn", default=0, type=int
+    )  # replace first att layer by ffn (sometimes better)
     parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
     parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
-    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer
+    parser.add_argument(
+        "--tiny_att_layer", default=-999, type=int
+    )  # tiny attention @ which layer
 
-    parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
+    parser.add_argument(
+        "--lr_init", default=6e-4, type=float
+    )  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
     parser.add_argument("--lr_final", default=1e-5, type=float)
-    parser.add_argument("--warmup_steps", default=0, type=int)  # try 50 if you load a model
+    parser.add_argument(
+        "--warmup_steps", default=0, type=int
+    )  # try 50 if you load a model
     parser.add_argument("--beta1", default=0.9, type=float)
-    parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
+    parser.add_argument(
+        "--beta2", default=0.99, type=float
+    )  # use 0.999 when your model is close to convergence
     parser.add_argument("--adam_eps", default=1e-8, type=float)
 
-    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
+    parser.add_argument(
+        "--grad_cp", default=0, type=int
+    )  # gradient checkpt: saves VRAM, but slower
     parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
-    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
+    parser.add_argument(
+        "--my_pile_shift", default=-1, type=int
+    )  # my special pile mode - text shift
     parser.add_argument("--my_pile_edecay", default=0, type=int)
-    parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
-    parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
+    parser.add_argument(
+        "--layerwise_lr", default=1, type=int
+    )  # layerwise lr for faster convergence (but slower it/s)
+    parser.add_argument(
+        "--ds_bucket_mb", default=200, type=int
+    )  # deepspeed bucket size in MB. 200 seems enough
     # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)
 
     parser.add_argument("--my_img_version", default=0, type=str)
     parser.add_argument("--my_img_size", default=0, type=int)
     parser.add_argument("--my_img_bit", default=0, type=int)
-    parser.add_argument("--my_img_clip", default='x', type=str)
+    parser.add_argument("--my_img_clip", default="x", type=str)
     parser.add_argument("--my_img_clip_scale", default=1, type=float)
     parser.add_argument("--my_img_l1_scale", default=0, type=float)
-    parser.add_argument("--my_img_encoder", default='x', type=str)
+    parser.add_argument("--my_img_encoder", default="x", type=str)
     # parser.add_argument("--my_img_noise_scale", default=0, type=float)
     parser.add_argument("--my_sample_len", default=0, type=int)
     parser.add_argument("--my_ffn_shift", default=1, type=int)
@@ -104,7 +136,7 @@ if __name__ == "__main__":
     parser.add_argument("--load_partial", default=0, type=int)
     parser.add_argument("--magic_prime", default=0, type=int)
     parser.add_argument("--my_qa_mask", default=0, type=int)
-    parser.add_argument("--my_testing", default='', type=str)
+    parser.add_argument("--my_testing", default="", type=str)
 
     parser.add_argument("--lora", action="store_true")
     parser.add_argument("--lora_load", default="", type=str)
@@ -122,18 +154,26 @@ if __name__ == "__main__":
     import numpy as np
     import torch
     from torch.utils.data import DataLoader
+
     if "deepspeed" in args.strategy:
         import deepspeed
    import pytorch_lightning as pl
     from pytorch_lightning import seed_everything
 
     if args.random_seed >= 0:
-        print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
+        print(
+            f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n"
+            * 3
+        )
         seed_everything(args.random_seed)
 
     np.set_printoptions(precision=4, suppress=True, linewidth=200)
-    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
-    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
+    warnings.filterwarnings(
+        "ignore", ".*Consider increasing the value of the `num_workers` argument*"
+    )
+    warnings.filterwarnings(
+        "ignore", ".*The progress bar already tracks a metric with the*"
+    )
     # os.environ["WDS_SHOW_SEED"] = "1"
 
     args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
@@ -158,7 +198,9 @@ if __name__ == "__main__":
         args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
         args.proj_dir = f"{args.proj_dir}-{args.run_name}"
     else:
-        args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
+        args.run_name = (
+            f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
+        )
     if not os.path.exists(args.proj_dir):
         os.makedirs(args.proj_dir)
 
@@ -240,24 +282,40 @@ if __name__ == "__main__":
     )
     rank_zero_info(str(vars(args)) + "\n")
 
-    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
+    assert args.data_type in [
+        "utf-8",
+        "utf-16le",
+        "numpy",
+        "binidx",
+        "dummy",
+        "wds_img",
+        "uint16",
+    ]
 
     if args.lr_final == 0 or args.lr_init == 0:
-        rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
+        rank_zero_info(
+            "\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n"
+        )
 
     assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
     os.environ["RWKV_FLOAT_MODE"] = args.precision
     if args.precision == "fp32":
         for i in range(10):
-            rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
+            rank_zero_info(
+                "\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n"
+            )
     if args.precision == "fp16":
-        rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
+        rank_zero_info(
+            "\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n"
+        )
 
     os.environ["RWKV_JIT_ON"] = "1"
     if "deepspeed_stage_3" in args.strategy:
         os.environ["RWKV_JIT_ON"] = "0"
     if args.lora and args.grad_cp == 1:
-        print('!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it')
+        print(
+            "!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it"
+        )
         os.environ["RWKV_JIT_ON"] = "0"
 
     torch.backends.cudnn.benchmark = True
@@ -284,20 +342,22 @@ if __name__ == "__main__":
     train_data = MyDataset(args)
     args.vocab_size = train_data.vocab_size
 
-    if args.data_type == 'wds_img':
+    if args.data_type == "wds_img":
         from src.model_img import RWKV_IMG
+
         assert args.lora, "LoRA not yet supported for RWKV_IMG"
         model = RWKV_IMG(args)
     else:
         from src.model import RWKV, LORA_CONFIG, LoraLinear
+
         if args.lora:
             assert args.lora_r > 0, "LoRA should have its `r` > 0"
             LORA_CONFIG["r"] = args.lora_r
             LORA_CONFIG["alpha"] = args.lora_alpha
             LORA_CONFIG["dropout"] = args.lora_dropout
-            LORA_CONFIG["parts"] = set(str(args.lora_parts).split(','))
-            enable_time_finetune = 'time' in LORA_CONFIG["parts"]
-            enable_ln_finetune = 'ln' in LORA_CONFIG["parts"]
+            LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
+            enable_time_finetune = "time" in LORA_CONFIG["parts"]
+            enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
         model = RWKV(args)
         # only train lora parameters
         if args.lora:
@@ -305,20 +365,24 @@ if __name__ == "__main__":
             for name, module in model.named_modules():
                 # have to check param name since it may have been wrapped by torchscript
                 if any(n.startswith("lora_") for n, _ in module.named_parameters()):
-                    print(f'  LoRA training module {name}')
+                    print(f"  LoRA training module {name}")
                     for pname, param in module.named_parameters():
-                        param.requires_grad = 'lora_' in pname
-                elif enable_ln_finetune and '.ln' in name:
-                    print(f'  LoRA additionally training module {name}')
+                        param.requires_grad = "lora_" in pname
+                elif enable_ln_finetune and ".ln" in name:
+                    print(f"  LoRA additionally training module {name}")
                     for param in module.parameters():
                         param.requires_grad = True
-                elif enable_time_finetune and any(n.startswith("time") for n, _ in module.named_parameters()):
+                elif enable_time_finetune and any(
+                    n.startswith("time") for n, _ in module.named_parameters()
+                ):
                     for pname, param in module.named_parameters():
                         if pname.startswith("time"):
-                            print(f'  LoRA additionally training parameter {pname}')
+                            print(f"  LoRA additionally training parameter {pname}")
                             param.requires_grad = True
 
-    if len(args.load_model) == 0 or args.my_pile_stage == 1:  # shall we build the initial weights?
+    if (
+        len(args.load_model) == 0 or args.my_pile_stage == 1
+    ):  # shall we build the initial weights?
         init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
         generate_init_weight(model, init_weight_name)  # save initial weights
         args.load_model = init_weight_name
@@ -346,27 +410,39 @@ if __name__ == "__main__":
     # If using LoRA, the LoRA keys might be missing in the original model
     model.load_state_dict(load_dict, strict=(not args.lora))
     if os.path.isfile(args.lora_load):
-        model.load_state_dict(torch.load(args.lora_load, map_location="cpu"),
-                              strict=False)
+        model.load_state_dict(
+            torch.load(args.lora_load, map_location="cpu"), strict=False
+        )
 
     trainer: Trainer = Trainer.from_argparse_args(
         args,
         callbacks=[train_callback(args)],
     )
-    
-    if (args.lr_init > 1e-4 or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8):
-        if 'I_KNOW_WHAT_IM_DOING' in os.environ:
+
+    if (
+        args.lr_init > 1e-4
+        or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
+    ):
+        if "I_KNOW_WHAT_IM_DOING" in os.environ:
             if trainer.global_rank == 0:
-                print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
-                print(f'  WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
-                print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print(
+                    f"  WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
+                )
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
         else:
             if trainer.global_rank == 0:
-                print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
-                print(f'  ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
-                print(f'  Unless you are sure this is what you want, adjust them accordingly')
-                print(f'  (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")')
-                print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+                print(
+                    f"  ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
+                )
+                print(
+                    f"  Unless you are sure this is what you want, adjust them accordingly"
+                )
+                print(
+                    f'  (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")'
+                )
+                print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
             exit(0)
 
     if trainer.global_rank == 0:
@@ -379,10 +455,22 @@ if __name__ == "__main__":
                 print(f"{str(shape[0]).ljust(5)}       {n}")
 
     if "deepspeed" in args.strategy:
-        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
-        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
+            args.ds_bucket_mb * 1000 * 1000
+        )
+        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
+            args.ds_bucket_mb * 1000 * 1000
+        )
 
     # must set shuffle=False, persistent_workers=False (because worker is in another thread)
-    data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
+    data_loader = DataLoader(
+        train_data,
+        shuffle=False,
+        pin_memory=True,
+        batch_size=args.micro_bsz,
+        num_workers=1,
+        persistent_workers=False,
+        drop_last=True,
+    )
 
     trainer.fit(model, data_loader)
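Note on the LORA_CONFIG and LoraLinear names imported from src.model in the hunk at line 342 above: their implementation lives in src/model.py and is not part of this diff. The following is only a minimal sketch of what such a LoRA linear layer conventionally looks like, consistent with the config keys the script sets (r, alpha, dropout) and with the "lora_" parameter-name check used when freezing weights; the class layout, parameter names, and initialization here are assumptions for illustration, not the repository's actual code.

# Hypothetical sketch (not the code from src/model.py): a conventional LoRA linear
# layer driven by a module-level config dict, as the training script assumes.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

LORA_CONFIG = {"r": 0, "alpha": 0, "dropout": 0, "parts": {"att", "ffn", "time", "ln"}}


class LoraLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        r, alpha, dropout = (
            LORA_CONFIG["r"],
            LORA_CONFIG["alpha"],
            LORA_CONFIG["dropout"],
        )
        assert r > 0, "LoRA should have its `r` > 0"
        # Frozen base weight; only the low-rank factors are trained.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.weight.requires_grad = False
        # lora_A projects down to rank r, lora_B projects back up; B starts at zero
        # so the adapter is a no-op at initialization. The "lora_" prefix is what
        # train.py's freezing loop keys on (param.requires_grad = "lora_" in pname).
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.lora_dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

    def forward(self, x):
        # Base projection plus the scaled low-rank update (alpha / r) * B A x.
        return F.linear(x, self.weight) + self.scaling * F.linear(
            F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
        )


if __name__ == "__main__":
    # Configure the adapter first, as train.py does from --lora_r / --lora_alpha / --lora_dropout.
    LORA_CONFIG.update(r=8, alpha=16, dropout=0.01)
    layer = LoraLinear(512, 512)
    print(layer(torch.randn(2, 512)).shape)  # torch.Size([2, 512])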