This commit is contained in:
josc146 2023-07-09 12:32:50 +08:00
parent e930eb5967
commit d8c70453ec

196
finetune/lora/train.py vendored
View File

@ -50,52 +50,84 @@ if __name__ == "__main__":
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--load_model", default="", type=str) # full path, with .pth parser.add_argument("--load_model", default="", type=str) # full path, with .pth
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb parser.add_argument(
"--wandb", default="", type=str
) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str) parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--random_seed", default="-1", type=int) parser.add_argument("--random_seed", default="-1", type=int)
parser.add_argument("--data_file", default="", type=str) parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str) parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data) parser.add_argument(
"--vocab_size", default=0, type=int
) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument("--ctx_len", default=1024, type=int) parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps parser.add_argument(
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final "--epoch_steps", default=1000, type=int
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x ) # a mini "epoch" has [epoch_steps] steps
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs" parser.add_argument(
"--epoch_count", default=500, type=int
) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument(
"--epoch_begin", default=0, type=int
) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument(
"--epoch_save", default=5, type=int
) # save the model every [epoch_save] "epochs"
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU) parser.add_argument(
"--micro_bsz", default=12, type=int
) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int) parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int) parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int) parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int) parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better) parser.add_argument(
"--pre_ffn", default=0, type=int
) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer parser.add_argument(
"--tiny_att_layer", default=-999, type=int
) # tiny attention @ which layer
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 parser.add_argument(
"--lr_init", default=6e-4, type=float
) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float) parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model parser.add_argument(
"--warmup_steps", default=0, type=int
) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float) parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence parser.add_argument(
"--beta2", default=0.99, type=float
) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float) parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower parser.add_argument(
"--grad_cp", default=0, type=int
) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift parser.add_argument(
"--my_pile_shift", default=-1, type=int
) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int) parser.add_argument("--my_pile_edecay", default=0, type=int)
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) parser.add_argument(
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough "--layerwise_lr", default=1, type=int
) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument(
"--ds_bucket_mb", default=200, type=int
) # deepspeed bucket size in MB. 200 seems enough
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful) # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
parser.add_argument("--my_img_version", default=0, type=str) parser.add_argument("--my_img_version", default=0, type=str)
parser.add_argument("--my_img_size", default=0, type=int) parser.add_argument("--my_img_size", default=0, type=int)
parser.add_argument("--my_img_bit", default=0, type=int) parser.add_argument("--my_img_bit", default=0, type=int)
parser.add_argument("--my_img_clip", default='x', type=str) parser.add_argument("--my_img_clip", default="x", type=str)
parser.add_argument("--my_img_clip_scale", default=1, type=float) parser.add_argument("--my_img_clip_scale", default=1, type=float)
parser.add_argument("--my_img_l1_scale", default=0, type=float) parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str) parser.add_argument("--my_img_encoder", default="x", type=str)
# parser.add_argument("--my_img_noise_scale", default=0, type=float) # parser.add_argument("--my_img_noise_scale", default=0, type=float)
parser.add_argument("--my_sample_len", default=0, type=int) parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int) parser.add_argument("--my_ffn_shift", default=1, type=int)
@ -104,7 +136,7 @@ if __name__ == "__main__":
parser.add_argument("--load_partial", default=0, type=int) parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int) parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int) parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str) parser.add_argument("--my_testing", default="", type=str)
parser.add_argument("--lora", action="store_true") parser.add_argument("--lora", action="store_true")
parser.add_argument("--lora_load", default="", type=str) parser.add_argument("--lora_load", default="", type=str)
@ -122,18 +154,26 @@ if __name__ == "__main__":
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
if "deepspeed" in args.strategy: if "deepspeed" in args.strategy:
import deepspeed import deepspeed
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
if args.random_seed >= 0: if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3) print(
f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n"
* 3
)
seed_everything(args.random_seed) seed_everything(args.random_seed)
np.set_printoptions(precision=4, suppress=True, linewidth=200) np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") warnings.filterwarnings(
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*") "ignore", ".*Consider increasing the value of the `num_workers` argument*"
)
warnings.filterwarnings(
"ignore", ".*The progress bar already tracks a metric with the*"
)
# os.environ["WDS_SHOW_SEED"] = "1" # os.environ["WDS_SHOW_SEED"] = "1"
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
@ -158,7 +198,9 @@ if __name__ == "__main__":
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}" args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
args.proj_dir = f"{args.proj_dir}-{args.run_name}" args.proj_dir = f"{args.proj_dir}-{args.run_name}"
else: else:
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" args.run_name = (
f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
)
if not os.path.exists(args.proj_dir): if not os.path.exists(args.proj_dir):
os.makedirs(args.proj_dir) os.makedirs(args.proj_dir)
@ -240,24 +282,40 @@ if __name__ == "__main__":
) )
rank_zero_info(str(vars(args)) + "\n") rank_zero_info(str(vars(args)) + "\n")
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"] assert args.data_type in [
"utf-8",
"utf-16le",
"numpy",
"binidx",
"dummy",
"wds_img",
"uint16",
]
if args.lr_final == 0 or args.lr_init == 0: if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n") rank_zero_info(
"\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n"
)
assert args.precision in ["fp32", "tf32", "fp16", "bf16"] assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32": if args.precision == "fp32":
for i in range(10): for i in range(10):
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n") rank_zero_info(
"\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n"
)
if args.precision == "fp16": if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n") rank_zero_info(
"\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n"
)
os.environ["RWKV_JIT_ON"] = "1" os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy: if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT_ON"] = "0" os.environ["RWKV_JIT_ON"] = "0"
if args.lora and args.grad_cp == 1: if args.lora and args.grad_cp == 1:
print('!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it') print(
"!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it"
)
os.environ["RWKV_JIT_ON"] = "0" os.environ["RWKV_JIT_ON"] = "0"
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
@ -284,20 +342,22 @@ if __name__ == "__main__":
train_data = MyDataset(args) train_data = MyDataset(args)
args.vocab_size = train_data.vocab_size args.vocab_size = train_data.vocab_size
if args.data_type == 'wds_img': if args.data_type == "wds_img":
from src.model_img import RWKV_IMG from src.model_img import RWKV_IMG
assert args.lora, "LoRA not yet supported for RWKV_IMG" assert args.lora, "LoRA not yet supported for RWKV_IMG"
model = RWKV_IMG(args) model = RWKV_IMG(args)
else: else:
from src.model import RWKV, LORA_CONFIG, LoraLinear from src.model import RWKV, LORA_CONFIG, LoraLinear
if args.lora: if args.lora:
assert args.lora_r > 0, "LoRA should have its `r` > 0" assert args.lora_r > 0, "LoRA should have its `r` > 0"
LORA_CONFIG["r"] = args.lora_r LORA_CONFIG["r"] = args.lora_r
LORA_CONFIG["alpha"] = args.lora_alpha LORA_CONFIG["alpha"] = args.lora_alpha
LORA_CONFIG["dropout"] = args.lora_dropout LORA_CONFIG["dropout"] = args.lora_dropout
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(',')) LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
enable_time_finetune = 'time' in LORA_CONFIG["parts"] enable_time_finetune = "time" in LORA_CONFIG["parts"]
enable_ln_finetune = 'ln' in LORA_CONFIG["parts"] enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
model = RWKV(args) model = RWKV(args)
# only train lora parameters # only train lora parameters
if args.lora: if args.lora:
@ -305,20 +365,24 @@ if __name__ == "__main__":
for name, module in model.named_modules(): for name, module in model.named_modules():
# have to check param name since it may have been wrapped by torchscript # have to check param name since it may have been wrapped by torchscript
if any(n.startswith("lora_") for n, _ in module.named_parameters()): if any(n.startswith("lora_") for n, _ in module.named_parameters()):
print(f' LoRA training module {name}') print(f" LoRA training module {name}")
for pname, param in module.named_parameters(): for pname, param in module.named_parameters():
param.requires_grad = 'lora_' in pname param.requires_grad = "lora_" in pname
elif enable_ln_finetune and '.ln' in name: elif enable_ln_finetune and ".ln" in name:
print(f' LoRA additionally training module {name}') print(f" LoRA additionally training module {name}")
for param in module.parameters(): for param in module.parameters():
param.requires_grad = True param.requires_grad = True
elif enable_time_finetune and any(n.startswith("time") for n, _ in module.named_parameters()): elif enable_time_finetune and any(
n.startswith("time") for n, _ in module.named_parameters()
):
for pname, param in module.named_parameters(): for pname, param in module.named_parameters():
if pname.startswith("time"): if pname.startswith("time"):
print(f' LoRA additionally training parameter {pname}') print(f" LoRA additionally training parameter {pname}")
param.requires_grad = True param.requires_grad = True
if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights? if (
len(args.load_model) == 0 or args.my_pile_stage == 1
): # shall we build the initial weights?
init_weight_name = f"{args.proj_dir}/rwkv-init.pth" init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
generate_init_weight(model, init_weight_name) # save initial weights generate_init_weight(model, init_weight_name) # save initial weights
args.load_model = init_weight_name args.load_model = init_weight_name
@ -346,27 +410,39 @@ if __name__ == "__main__":
# If using LoRA, the LoRA keys might be missing in the original model # If using LoRA, the LoRA keys might be missing in the original model
model.load_state_dict(load_dict, strict=(not args.lora)) model.load_state_dict(load_dict, strict=(not args.lora))
if os.path.isfile(args.lora_load): if os.path.isfile(args.lora_load):
model.load_state_dict(torch.load(args.lora_load, map_location="cpu"), model.load_state_dict(
strict=False) torch.load(args.lora_load, map_location="cpu"), strict=False
)
trainer: Trainer = Trainer.from_argparse_args( trainer: Trainer = Trainer.from_argparse_args(
args, args,
callbacks=[train_callback(args)], callbacks=[train_callback(args)],
) )
if (args.lr_init > 1e-4 or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8): if (
if 'I_KNOW_WHAT_IM_DOING' in os.environ: args.lr_init > 1e-4
or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
):
if "I_KNOW_WHAT_IM_DOING" in os.environ:
if trainer.global_rank == 0: if trainer.global_rank == 0:
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print(f' WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)') print(
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') f" WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
)
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
else: else:
if trainer.global_rank == 0: if trainer.global_rank == 0:
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print(f' ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)') print(
print(f' Unless you are sure this is what you want, adjust them accordingly') f" ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
print(f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")') )
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') print(
f" Unless you are sure this is what you want, adjust them accordingly"
)
print(
f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")'
)
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
exit(0) exit(0)
if trainer.global_rank == 0: if trainer.global_rank == 0:
@ -379,10 +455,22 @@ if __name__ == "__main__":
print(f"{str(shape[0]).ljust(5)} {n}") print(f"{str(shape[0]).ljust(5)} {n}")
if "deepspeed" in args.strategy: if "deepspeed" in args.strategy:
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 args.ds_bucket_mb * 1000 * 1000
)
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
args.ds_bucket_mb * 1000 * 1000
)
# must set shuffle=False, persistent_workers=False (because worker is in another thread) # must set shuffle=False, persistent_workers=False (because worker is in another thread)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) data_loader = DataLoader(
train_data,
shuffle=False,
pin_memory=True,
batch_size=args.micro_bsz,
num_workers=1,
persistent_workers=False,
drop_last=True,
)
trainer.fit(model, data_loader) trainer.fit(model, data_loader)