format
This commit is contained in:
parent
e930eb5967
commit
d8c70453ec
196
finetune/lora/train.py
vendored
196
finetune/lora/train.py
vendored
@ -50,52 +50,84 @@ if __name__ == "__main__":
|
|||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
|
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
|
||||||
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
|
parser.add_argument(
|
||||||
|
"--wandb", default="", type=str
|
||||||
|
) # wandb project name. if "" then don't use wandb
|
||||||
parser.add_argument("--proj_dir", default="out", type=str)
|
parser.add_argument("--proj_dir", default="out", type=str)
|
||||||
parser.add_argument("--random_seed", default="-1", type=int)
|
parser.add_argument("--random_seed", default="-1", type=int)
|
||||||
|
|
||||||
parser.add_argument("--data_file", default="", type=str)
|
parser.add_argument("--data_file", default="", type=str)
|
||||||
parser.add_argument("--data_type", default="utf-8", type=str)
|
parser.add_argument("--data_type", default="utf-8", type=str)
|
||||||
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
|
parser.add_argument(
|
||||||
|
"--vocab_size", default=0, type=int
|
||||||
|
) # vocab_size = 0 means auto (for char-level LM and .txt data)
|
||||||
|
|
||||||
parser.add_argument("--ctx_len", default=1024, type=int)
|
parser.add_argument("--ctx_len", default=1024, type=int)
|
||||||
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
|
parser.add_argument(
|
||||||
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
|
"--epoch_steps", default=1000, type=int
|
||||||
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
|
) # a mini "epoch" has [epoch_steps] steps
|
||||||
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
|
parser.add_argument(
|
||||||
|
"--epoch_count", default=500, type=int
|
||||||
|
) # train for this many "epochs". will continue afterwards with lr = lr_final
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch_begin", default=0, type=int
|
||||||
|
) # if you load a model trained for x "epochs", set epoch_begin = x
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch_save", default=5, type=int
|
||||||
|
) # save the model every [epoch_save] "epochs"
|
||||||
|
|
||||||
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
|
parser.add_argument(
|
||||||
|
"--micro_bsz", default=12, type=int
|
||||||
|
) # micro batch size (batch size per GPU)
|
||||||
parser.add_argument("--n_layer", default=6, type=int)
|
parser.add_argument("--n_layer", default=6, type=int)
|
||||||
parser.add_argument("--n_embd", default=512, type=int)
|
parser.add_argument("--n_embd", default=512, type=int)
|
||||||
parser.add_argument("--dim_att", default=0, type=int)
|
parser.add_argument("--dim_att", default=0, type=int)
|
||||||
parser.add_argument("--dim_ffn", default=0, type=int)
|
parser.add_argument("--dim_ffn", default=0, type=int)
|
||||||
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
|
parser.add_argument(
|
||||||
|
"--pre_ffn", default=0, type=int
|
||||||
|
) # replace first att layer by ffn (sometimes better)
|
||||||
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
|
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
|
||||||
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
|
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
|
||||||
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
|
parser.add_argument(
|
||||||
|
"--tiny_att_layer", default=-999, type=int
|
||||||
|
) # tiny attention @ which layer
|
||||||
|
|
||||||
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
|
parser.add_argument(
|
||||||
|
"--lr_init", default=6e-4, type=float
|
||||||
|
) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
|
||||||
parser.add_argument("--lr_final", default=1e-5, type=float)
|
parser.add_argument("--lr_final", default=1e-5, type=float)
|
||||||
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
|
parser.add_argument(
|
||||||
|
"--warmup_steps", default=0, type=int
|
||||||
|
) # try 50 if you load a model
|
||||||
parser.add_argument("--beta1", default=0.9, type=float)
|
parser.add_argument("--beta1", default=0.9, type=float)
|
||||||
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
|
parser.add_argument(
|
||||||
|
"--beta2", default=0.99, type=float
|
||||||
|
) # use 0.999 when your model is close to convergence
|
||||||
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
||||||
|
|
||||||
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
|
parser.add_argument(
|
||||||
|
"--grad_cp", default=0, type=int
|
||||||
|
) # gradient checkpt: saves VRAM, but slower
|
||||||
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
|
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
|
||||||
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
|
parser.add_argument(
|
||||||
|
"--my_pile_shift", default=-1, type=int
|
||||||
|
) # my special pile mode - text shift
|
||||||
parser.add_argument("--my_pile_edecay", default=0, type=int)
|
parser.add_argument("--my_pile_edecay", default=0, type=int)
|
||||||
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
|
parser.add_argument(
|
||||||
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
|
"--layerwise_lr", default=1, type=int
|
||||||
|
) # layerwise lr for faster convergence (but slower it/s)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds_bucket_mb", default=200, type=int
|
||||||
|
) # deepspeed bucket size in MB. 200 seems enough
|
||||||
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
|
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
|
||||||
|
|
||||||
parser.add_argument("--my_img_version", default=0, type=str)
|
parser.add_argument("--my_img_version", default=0, type=str)
|
||||||
parser.add_argument("--my_img_size", default=0, type=int)
|
parser.add_argument("--my_img_size", default=0, type=int)
|
||||||
parser.add_argument("--my_img_bit", default=0, type=int)
|
parser.add_argument("--my_img_bit", default=0, type=int)
|
||||||
parser.add_argument("--my_img_clip", default='x', type=str)
|
parser.add_argument("--my_img_clip", default="x", type=str)
|
||||||
parser.add_argument("--my_img_clip_scale", default=1, type=float)
|
parser.add_argument("--my_img_clip_scale", default=1, type=float)
|
||||||
parser.add_argument("--my_img_l1_scale", default=0, type=float)
|
parser.add_argument("--my_img_l1_scale", default=0, type=float)
|
||||||
parser.add_argument("--my_img_encoder", default='x', type=str)
|
parser.add_argument("--my_img_encoder", default="x", type=str)
|
||||||
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
|
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
|
||||||
parser.add_argument("--my_sample_len", default=0, type=int)
|
parser.add_argument("--my_sample_len", default=0, type=int)
|
||||||
parser.add_argument("--my_ffn_shift", default=1, type=int)
|
parser.add_argument("--my_ffn_shift", default=1, type=int)
|
||||||
@ -104,7 +136,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--load_partial", default=0, type=int)
|
parser.add_argument("--load_partial", default=0, type=int)
|
||||||
parser.add_argument("--magic_prime", default=0, type=int)
|
parser.add_argument("--magic_prime", default=0, type=int)
|
||||||
parser.add_argument("--my_qa_mask", default=0, type=int)
|
parser.add_argument("--my_qa_mask", default=0, type=int)
|
||||||
parser.add_argument("--my_testing", default='', type=str)
|
parser.add_argument("--my_testing", default="", type=str)
|
||||||
|
|
||||||
parser.add_argument("--lora", action="store_true")
|
parser.add_argument("--lora", action="store_true")
|
||||||
parser.add_argument("--lora_load", default="", type=str)
|
parser.add_argument("--lora_load", default="", type=str)
|
||||||
@ -122,18 +154,26 @@ if __name__ == "__main__":
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
if "deepspeed" in args.strategy:
|
if "deepspeed" in args.strategy:
|
||||||
import deepspeed
|
import deepspeed
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
|
|
||||||
if args.random_seed >= 0:
|
if args.random_seed >= 0:
|
||||||
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
|
print(
|
||||||
|
f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n"
|
||||||
|
* 3
|
||||||
|
)
|
||||||
seed_everything(args.random_seed)
|
seed_everything(args.random_seed)
|
||||||
|
|
||||||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||||
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
|
warnings.filterwarnings(
|
||||||
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
|
"ignore", ".*Consider increasing the value of the `num_workers` argument*"
|
||||||
|
)
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore", ".*The progress bar already tracks a metric with the*"
|
||||||
|
)
|
||||||
# os.environ["WDS_SHOW_SEED"] = "1"
|
# os.environ["WDS_SHOW_SEED"] = "1"
|
||||||
|
|
||||||
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
|
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
|
||||||
@ -158,7 +198,9 @@ if __name__ == "__main__":
|
|||||||
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
|
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
|
||||||
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
|
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
|
||||||
else:
|
else:
|
||||||
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
|
args.run_name = (
|
||||||
|
f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
|
||||||
|
)
|
||||||
if not os.path.exists(args.proj_dir):
|
if not os.path.exists(args.proj_dir):
|
||||||
os.makedirs(args.proj_dir)
|
os.makedirs(args.proj_dir)
|
||||||
|
|
||||||
@ -240,24 +282,40 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
rank_zero_info(str(vars(args)) + "\n")
|
rank_zero_info(str(vars(args)) + "\n")
|
||||||
|
|
||||||
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
|
assert args.data_type in [
|
||||||
|
"utf-8",
|
||||||
|
"utf-16le",
|
||||||
|
"numpy",
|
||||||
|
"binidx",
|
||||||
|
"dummy",
|
||||||
|
"wds_img",
|
||||||
|
"uint16",
|
||||||
|
]
|
||||||
|
|
||||||
if args.lr_final == 0 or args.lr_init == 0:
|
if args.lr_final == 0 or args.lr_init == 0:
|
||||||
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
|
rank_zero_info(
|
||||||
|
"\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
|
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
|
||||||
os.environ["RWKV_FLOAT_MODE"] = args.precision
|
os.environ["RWKV_FLOAT_MODE"] = args.precision
|
||||||
if args.precision == "fp32":
|
if args.precision == "fp32":
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
|
rank_zero_info(
|
||||||
|
"\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n"
|
||||||
|
)
|
||||||
if args.precision == "fp16":
|
if args.precision == "fp16":
|
||||||
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
|
rank_zero_info(
|
||||||
|
"\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
os.environ["RWKV_JIT_ON"] = "1"
|
os.environ["RWKV_JIT_ON"] = "1"
|
||||||
if "deepspeed_stage_3" in args.strategy:
|
if "deepspeed_stage_3" in args.strategy:
|
||||||
os.environ["RWKV_JIT_ON"] = "0"
|
os.environ["RWKV_JIT_ON"] = "0"
|
||||||
if args.lora and args.grad_cp == 1:
|
if args.lora and args.grad_cp == 1:
|
||||||
print('!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it')
|
print(
|
||||||
|
"!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it"
|
||||||
|
)
|
||||||
os.environ["RWKV_JIT_ON"] = "0"
|
os.environ["RWKV_JIT_ON"] = "0"
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
@ -284,20 +342,22 @@ if __name__ == "__main__":
|
|||||||
train_data = MyDataset(args)
|
train_data = MyDataset(args)
|
||||||
args.vocab_size = train_data.vocab_size
|
args.vocab_size = train_data.vocab_size
|
||||||
|
|
||||||
if args.data_type == 'wds_img':
|
if args.data_type == "wds_img":
|
||||||
from src.model_img import RWKV_IMG
|
from src.model_img import RWKV_IMG
|
||||||
|
|
||||||
assert args.lora, "LoRA not yet supported for RWKV_IMG"
|
assert args.lora, "LoRA not yet supported for RWKV_IMG"
|
||||||
model = RWKV_IMG(args)
|
model = RWKV_IMG(args)
|
||||||
else:
|
else:
|
||||||
from src.model import RWKV, LORA_CONFIG, LoraLinear
|
from src.model import RWKV, LORA_CONFIG, LoraLinear
|
||||||
|
|
||||||
if args.lora:
|
if args.lora:
|
||||||
assert args.lora_r > 0, "LoRA should have its `r` > 0"
|
assert args.lora_r > 0, "LoRA should have its `r` > 0"
|
||||||
LORA_CONFIG["r"] = args.lora_r
|
LORA_CONFIG["r"] = args.lora_r
|
||||||
LORA_CONFIG["alpha"] = args.lora_alpha
|
LORA_CONFIG["alpha"] = args.lora_alpha
|
||||||
LORA_CONFIG["dropout"] = args.lora_dropout
|
LORA_CONFIG["dropout"] = args.lora_dropout
|
||||||
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(','))
|
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
|
||||||
enable_time_finetune = 'time' in LORA_CONFIG["parts"]
|
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
||||||
enable_ln_finetune = 'ln' in LORA_CONFIG["parts"]
|
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
||||||
model = RWKV(args)
|
model = RWKV(args)
|
||||||
# only train lora parameters
|
# only train lora parameters
|
||||||
if args.lora:
|
if args.lora:
|
||||||
@ -305,20 +365,24 @@ if __name__ == "__main__":
|
|||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
# have to check param name since it may have been wrapped by torchscript
|
# have to check param name since it may have been wrapped by torchscript
|
||||||
if any(n.startswith("lora_") for n, _ in module.named_parameters()):
|
if any(n.startswith("lora_") for n, _ in module.named_parameters()):
|
||||||
print(f' LoRA training module {name}')
|
print(f" LoRA training module {name}")
|
||||||
for pname, param in module.named_parameters():
|
for pname, param in module.named_parameters():
|
||||||
param.requires_grad = 'lora_' in pname
|
param.requires_grad = "lora_" in pname
|
||||||
elif enable_ln_finetune and '.ln' in name:
|
elif enable_ln_finetune and ".ln" in name:
|
||||||
print(f' LoRA additionally training module {name}')
|
print(f" LoRA additionally training module {name}")
|
||||||
for param in module.parameters():
|
for param in module.parameters():
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
elif enable_time_finetune and any(n.startswith("time") for n, _ in module.named_parameters()):
|
elif enable_time_finetune and any(
|
||||||
|
n.startswith("time") for n, _ in module.named_parameters()
|
||||||
|
):
|
||||||
for pname, param in module.named_parameters():
|
for pname, param in module.named_parameters():
|
||||||
if pname.startswith("time"):
|
if pname.startswith("time"):
|
||||||
print(f' LoRA additionally training parameter {pname}')
|
print(f" LoRA additionally training parameter {pname}")
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
|
|
||||||
if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
|
if (
|
||||||
|
len(args.load_model) == 0 or args.my_pile_stage == 1
|
||||||
|
): # shall we build the initial weights?
|
||||||
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
|
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
|
||||||
generate_init_weight(model, init_weight_name) # save initial weights
|
generate_init_weight(model, init_weight_name) # save initial weights
|
||||||
args.load_model = init_weight_name
|
args.load_model = init_weight_name
|
||||||
@ -346,27 +410,39 @@ if __name__ == "__main__":
|
|||||||
# If using LoRA, the LoRA keys might be missing in the original model
|
# If using LoRA, the LoRA keys might be missing in the original model
|
||||||
model.load_state_dict(load_dict, strict=(not args.lora))
|
model.load_state_dict(load_dict, strict=(not args.lora))
|
||||||
if os.path.isfile(args.lora_load):
|
if os.path.isfile(args.lora_load):
|
||||||
model.load_state_dict(torch.load(args.lora_load, map_location="cpu"),
|
model.load_state_dict(
|
||||||
strict=False)
|
torch.load(args.lora_load, map_location="cpu"), strict=False
|
||||||
|
)
|
||||||
|
|
||||||
trainer: Trainer = Trainer.from_argparse_args(
|
trainer: Trainer = Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
callbacks=[train_callback(args)],
|
callbacks=[train_callback(args)],
|
||||||
)
|
)
|
||||||
|
|
||||||
if (args.lr_init > 1e-4 or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8):
|
if (
|
||||||
if 'I_KNOW_WHAT_IM_DOING' in os.environ:
|
args.lr_init > 1e-4
|
||||||
|
or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
|
||||||
|
):
|
||||||
|
if "I_KNOW_WHAT_IM_DOING" in os.environ:
|
||||||
if trainer.global_rank == 0:
|
if trainer.global_rank == 0:
|
||||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
print(f' WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
|
print(
|
||||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
f" WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
|
||||||
|
)
|
||||||
|
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
else:
|
else:
|
||||||
if trainer.global_rank == 0:
|
if trainer.global_rank == 0:
|
||||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
print(f' ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
|
print(
|
||||||
print(f' Unless you are sure this is what you want, adjust them accordingly')
|
f" ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
|
||||||
print(f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")')
|
)
|
||||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
print(
|
||||||
|
f" Unless you are sure this is what you want, adjust them accordingly"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")'
|
||||||
|
)
|
||||||
|
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
if trainer.global_rank == 0:
|
if trainer.global_rank == 0:
|
||||||
@ -379,10 +455,22 @@ if __name__ == "__main__":
|
|||||||
print(f"{str(shape[0]).ljust(5)} {n}")
|
print(f"{str(shape[0]).ljust(5)} {n}")
|
||||||
|
|
||||||
if "deepspeed" in args.strategy:
|
if "deepspeed" in args.strategy:
|
||||||
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
|
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
|
||||||
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
|
args.ds_bucket_mb * 1000 * 1000
|
||||||
|
)
|
||||||
|
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
|
||||||
|
args.ds_bucket_mb * 1000 * 1000
|
||||||
|
)
|
||||||
|
|
||||||
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
||||||
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
|
data_loader = DataLoader(
|
||||||
|
train_data,
|
||||||
|
shuffle=False,
|
||||||
|
pin_memory=True,
|
||||||
|
batch_size=args.micro_bsz,
|
||||||
|
num_workers=1,
|
||||||
|
persistent_workers=False,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
trainer.fit(model, data_loader)
|
trainer.fit(model, data_loader)
|
||||||
|
Loading…
Reference in New Issue
Block a user