This commit is contained in:
184
finetune/lora/v6/train.py
vendored
184
finetune/lora/v6/train.py
vendored
@@ -1,6 +1,7 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
import os
|
||||
|
||||
import logging
|
||||
|
||||
@@ -110,7 +111,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--magic_prime", default=0, type=int)
|
||||
parser.add_argument("--my_qa_mask", default=0, type=int)
|
||||
parser.add_argument("--my_random_steps", default=0, type=int)
|
||||
parser.add_argument("--my_testing", default="", type=str)
|
||||
parser.add_argument("--my_testing", default="x052", type=str)
|
||||
parser.add_argument("--my_exit", default=99999999, type=int)
|
||||
parser.add_argument("--my_exit_tokens", default=0, type=int)
|
||||
|
||||
@@ -123,6 +124,29 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--lora_dropout", default=0.01, type=float)
|
||||
parser.add_argument("--lora_parts", default="att,ln,time", type=str)
|
||||
|
||||
# LISA
|
||||
parser.add_argument("--LISA", action="store_true")
|
||||
parser.add_argument("--lisa_r", default=2, type=int)
|
||||
parser.add_argument("--lisa_k", default=100, type=int)
|
||||
|
||||
# PISSA
|
||||
parser.add_argument("--PISSA", action="store_true")
|
||||
parser.add_argument("--svd_niter", default=4, type=int)
|
||||
|
||||
# quant
|
||||
parser.add_argument("--quant", default="none", type=str)
|
||||
|
||||
# dataset
|
||||
parser.add_argument("--dataload", default="get", type=str)
|
||||
|
||||
# state tuning
|
||||
parser.add_argument("--state_tune", action="store_true")
|
||||
|
||||
parser.add_argument("--chunk_ctx", default=512, type=int)
|
||||
# fla
|
||||
parser.add_argument("--fla", action="store_true")
|
||||
parser.add_argument("--train_type", default="none", type=str)
|
||||
|
||||
if pl.__version__[0] == "2":
|
||||
parser.add_argument("--accelerator", default="gpu", type=str)
|
||||
parser.add_argument("--strategy", default="auto", type=str)
|
||||
@@ -175,6 +199,14 @@ if __name__ == "__main__":
|
||||
os.environ["RWKV_MY_TESTING"] = args.my_testing
|
||||
os.environ["RWKV_CTXLEN"] = str(args.ctx_len)
|
||||
os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)
|
||||
######state tuning
|
||||
os.environ["RWKV_TRAIN_TYPE"] = ""
|
||||
if args.train_type == "state":
|
||||
os.environ["RWKV_TRAIN_TYPE"] = "states"
|
||||
elif args.train_type == "infctx":
|
||||
os.environ["RWKV_TRAIN_TYPE"] = "infctx"
|
||||
|
||||
os.environ["WKV"] = "fla" if args.fla else ""
|
||||
if args.dim_att <= 0:
|
||||
args.dim_att = args.n_embd
|
||||
if args.dim_ffn <= 0:
|
||||
@@ -323,11 +355,68 @@ if __name__ == "__main__":
|
||||
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
||||
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
||||
model = RWKV(args)
|
||||
|
||||
if args.lora:
|
||||
freeze = False
|
||||
if args.lora or args.LISA or args.train_type == "state":
|
||||
model.requires_grad_(False)
|
||||
for name, module in model.named_modules():
|
||||
freeze = True
|
||||
|
||||
if args.state_tune or args.train_type == "state":
|
||||
for name, module in model.named_modules():
|
||||
for pname, param in module.named_parameters():
|
||||
if "state" in pname:
|
||||
param.requires_grad = True
|
||||
break
|
||||
|
||||
if args.LISA:
|
||||
import re
|
||||
|
||||
select_layers = np.random.choice(
|
||||
range(args.n_layer), args.lisa_r, replace=False
|
||||
)
|
||||
for name, module in model.named_modules():
|
||||
for pname, param in module.named_parameters():
|
||||
if (
|
||||
"emb" in pname
|
||||
or "head" in pname
|
||||
or ".ln" in pname
|
||||
or "time" in pname
|
||||
):
|
||||
param.requires_grad = True
|
||||
match = re.search(r"\d+", pname)
|
||||
if match:
|
||||
number = int(match.group())
|
||||
if number in select_layers:
|
||||
param.requires_grad = True
|
||||
break
|
||||
|
||||
elif args.lora:
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if len(args.load_model) == 0:
|
||||
if any(n.startswith("emb.") for n, _ in module.named_parameters()):
|
||||
for pname, param in module.named_parameters():
|
||||
if "emb.weight" == pname:
|
||||
print(f" EMB additionally training module {pname}")
|
||||
param.requires_grad = True
|
||||
if any(n.startswith("head.") for n, _ in module.named_parameters()):
|
||||
for pname, param in module.named_parameters():
|
||||
if "head.weight" == pname:
|
||||
print(f" head additionally training module {pname}")
|
||||
param.requires_grad = True
|
||||
if "ln" in name:
|
||||
print(f" LoRA additionally training module {name}")
|
||||
for param in module.parameters():
|
||||
param.requires_grad = True
|
||||
if any(n.startswith("emb.") for n, _ in module.named_parameters()):
|
||||
for pname, param in module.named_parameters():
|
||||
if args.emb and "emb.weight" == pname:
|
||||
print(f" EMB additionally training module {pname}")
|
||||
param.requires_grad = True
|
||||
if any(n.startswith("head.") for n, _ in module.named_parameters()):
|
||||
for pname, param in module.named_parameters():
|
||||
if args.emb and "head.weight" == pname:
|
||||
print(f" head additionally training module {pname}")
|
||||
param.requires_grad = True
|
||||
if any(n.startswith("lora_") for n, _ in module.named_parameters()):
|
||||
print(f" LoRA additionally training module {name}")
|
||||
for pname, param in module.named_parameters():
|
||||
@@ -376,11 +465,26 @@ if __name__ == "__main__":
|
||||
for k in model.state_dict():
|
||||
if k not in load_keys:
|
||||
load_dict[k] = model.state_dict()[k]
|
||||
model.load_state_dict(load_dict, strict=(not args.lora))
|
||||
model.load_state_dict(load_dict, strict=(not freeze))
|
||||
if os.path.isfile(args.lora_load):
|
||||
model.load_state_dict(
|
||||
torch.load(args.lora_load, map_location="cpu"), strict=False
|
||||
)
|
||||
if args.PISSA:
|
||||
init_dict = {}
|
||||
rank_zero_info(f"########## Init PISSA... ##########")
|
||||
for name, m in model.named_modules():
|
||||
if hasattr(m, "pissa_init") and callable(getattr(m, "pissa_init")):
|
||||
m.pissa_init(args.svd_niter)
|
||||
init_dict[f"{name}.init_lora_A"] = m.lora_A.data
|
||||
init_dict[f"{name}.init_lora_B"] = m.lora_B.data
|
||||
torch.save(init_dict, f"{args.proj_dir}/init_lora.pth")
|
||||
|
||||
if args.quant != "none":
|
||||
rank_zero_info(f"########## Quant... ##########")
|
||||
for name, m in model.named_modules():
|
||||
if hasattr(m, "quant") and callable(getattr(m, "quant")):
|
||||
m.quant(args.quant)
|
||||
|
||||
if pl.__version__[0] == "2":
|
||||
trainer = Trainer(
|
||||
@@ -434,3 +538,73 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
trainer.fit(model, data_loader)
|
||||
# if args.LISA:
|
||||
# args.load_model=f'rwkv-0.pth'
|
||||
# model = RWKV(args)
|
||||
# model.requires_grad_(False)
|
||||
|
||||
# select_layers = np.random.choice(range(args.n_layer), args.lisa_r, replace=False)
|
||||
# for name, module in model.named_modules():
|
||||
# for pname, param in module.named_parameters():
|
||||
# if 'emb' in pname or 'head' in pname or '.ln' in pname or 'time' in pname :
|
||||
# param.requires_grad = True
|
||||
# match = re.search(r'\d+', pname)
|
||||
# if match:
|
||||
# number = int(match.group())
|
||||
# if number in select_layers:
|
||||
# param.requires_grad = True
|
||||
# break
|
||||
# rank_zero_info(f"########## Loading {args.load_model}... ##########")
|
||||
# try:
|
||||
# load_dict = torch.load(args.load_model, map_location="cpu")
|
||||
# load_keys = list(load_dict.keys())
|
||||
# for k in load_keys:
|
||||
# if k.startswith('_forward_module.'):
|
||||
# load_dict[k.replace('_forward_module.','')] = load_dict[k]
|
||||
# del load_dict[k]
|
||||
# except:
|
||||
# rank_zero_info(f"Bad checkpoint {args.load_model}")
|
||||
# if args.my_pile_stage >= 2: # try again using another checkpoint
|
||||
# max_p = args.my_pile_prev_p
|
||||
# if max_p == -1:
|
||||
# args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
||||
# else:
|
||||
# args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
||||
# args.epoch_begin = max_p + 1
|
||||
# rank_zero_info(f"Trying {args.load_model}")
|
||||
# load_dict = torch.load(args.load_model, map_location="cpu")
|
||||
|
||||
# if args.load_partial == 1:
|
||||
# load_keys = load_dict.keys()
|
||||
# for k in model.state_dict():
|
||||
# if k not in load_keys:
|
||||
# load_dict[k] = model.state_dict()[k]
|
||||
# model.load_state_dict(load_dict, strict=(not args.lora))
|
||||
|
||||
# if pl.__version__[0]=='2':
|
||||
# trainer = Trainer(accelerator=args.accelerator,strategy=args.strategy,devices=args.devices,num_nodes=args.num_nodes,precision=args.precision,
|
||||
# logger=args.logger,callbacks=[train_callback(args)],max_epochs=args.max_epochs,check_val_every_n_epoch=args.check_val_every_n_epoch,num_sanity_val_steps=args.num_sanity_val_steps,
|
||||
# log_every_n_steps=args.log_every_n_steps,enable_checkpointing=args.enable_checkpointing,accumulate_grad_batches=args.accumulate_grad_batches,gradient_clip_val=args.gradient_clip_val)
|
||||
# else:
|
||||
# trainer = Trainer.from_argparse_args(
|
||||
# args,
|
||||
# callbacks=[train_callback(args)],
|
||||
# )
|
||||
|
||||
# if trainer.global_rank == 0:
|
||||
# for n in model.state_dict():
|
||||
# shape = model.state_dict()[n].shape
|
||||
# shape = [i for i in shape if i != 1]
|
||||
# if len(shape) > 1:
|
||||
# print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
|
||||
# else:
|
||||
# print(f"{str(shape[0]).ljust(5)} {n}")
|
||||
|
||||
# if "deepspeed" in args.strategy:
|
||||
# trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
|
||||
# trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
|
||||
|
||||
# # must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
||||
# data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
|
||||
|
||||
# trainer.fit(model, data_loader)
|
||||
|
||||
Reference in New Issue
Block a user