This commit is contained in:
josc146
2024-05-28 22:35:47 +08:00
parent 3488d22d22
commit f05a4acb04
138 changed files with 29047 additions and 334 deletions

View File

@@ -1,6 +1,7 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os
import logging
@@ -110,7 +111,7 @@ if __name__ == "__main__":
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_random_steps", default=0, type=int)
parser.add_argument("--my_testing", default="", type=str)
parser.add_argument("--my_testing", default="x052", type=str)
parser.add_argument("--my_exit", default=99999999, type=int)
parser.add_argument("--my_exit_tokens", default=0, type=int)
@@ -123,6 +124,29 @@ if __name__ == "__main__":
parser.add_argument("--lora_dropout", default=0.01, type=float)
parser.add_argument("--lora_parts", default="att,ln,time", type=str)
# LISA
parser.add_argument("--LISA", action="store_true")
parser.add_argument("--lisa_r", default=2, type=int)
parser.add_argument("--lisa_k", default=100, type=int)
# PISSA
parser.add_argument("--PISSA", action="store_true")
parser.add_argument("--svd_niter", default=4, type=int)
# quant
parser.add_argument("--quant", default="none", type=str)
# dataset
parser.add_argument("--dataload", default="get", type=str)
# state tuning
parser.add_argument("--state_tune", action="store_true")
parser.add_argument("--chunk_ctx", default=512, type=int)
# fla
parser.add_argument("--fla", action="store_true")
parser.add_argument("--train_type", default="none", type=str)
if pl.__version__[0] == "2":
parser.add_argument("--accelerator", default="gpu", type=str)
parser.add_argument("--strategy", default="auto", type=str)
@@ -175,6 +199,14 @@ if __name__ == "__main__":
os.environ["RWKV_MY_TESTING"] = args.my_testing
os.environ["RWKV_CTXLEN"] = str(args.ctx_len)
os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)
######state tuning
os.environ["RWKV_TRAIN_TYPE"] = ""
if args.train_type == "state":
os.environ["RWKV_TRAIN_TYPE"] = "states"
elif args.train_type == "infctx":
os.environ["RWKV_TRAIN_TYPE"] = "infctx"
os.environ["WKV"] = "fla" if args.fla else ""
if args.dim_att <= 0:
args.dim_att = args.n_embd
if args.dim_ffn <= 0:
@@ -323,11 +355,68 @@ if __name__ == "__main__":
enable_time_finetune = "time" in LORA_CONFIG["parts"]
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
model = RWKV(args)
if args.lora:
freeze = False
if args.lora or args.LISA or args.train_type == "state":
model.requires_grad_(False)
for name, module in model.named_modules():
freeze = True
if args.state_tune or args.train_type == "state":
for name, module in model.named_modules():
for pname, param in module.named_parameters():
if "state" in pname:
param.requires_grad = True
break
if args.LISA:
import re
select_layers = np.random.choice(
range(args.n_layer), args.lisa_r, replace=False
)
for name, module in model.named_modules():
for pname, param in module.named_parameters():
if (
"emb" in pname
or "head" in pname
or ".ln" in pname
or "time" in pname
):
param.requires_grad = True
match = re.search(r"\d+", pname)
if match:
number = int(match.group())
if number in select_layers:
param.requires_grad = True
break
elif args.lora:
for name, module in model.named_modules():
if len(args.load_model) == 0:
if any(n.startswith("emb.") for n, _ in module.named_parameters()):
for pname, param in module.named_parameters():
if "emb.weight" == pname:
print(f" EMB additionally training module {pname}")
param.requires_grad = True
if any(n.startswith("head.") for n, _ in module.named_parameters()):
for pname, param in module.named_parameters():
if "head.weight" == pname:
print(f" head additionally training module {pname}")
param.requires_grad = True
if "ln" in name:
print(f" LoRA additionally training module {name}")
for param in module.parameters():
param.requires_grad = True
if any(n.startswith("emb.") for n, _ in module.named_parameters()):
for pname, param in module.named_parameters():
if args.emb and "emb.weight" == pname:
print(f" EMB additionally training module {pname}")
param.requires_grad = True
if any(n.startswith("head.") for n, _ in module.named_parameters()):
for pname, param in module.named_parameters():
if args.emb and "head.weight" == pname:
print(f" head additionally training module {pname}")
param.requires_grad = True
if any(n.startswith("lora_") for n, _ in module.named_parameters()):
print(f" LoRA additionally training module {name}")
for pname, param in module.named_parameters():
@@ -376,11 +465,26 @@ if __name__ == "__main__":
for k in model.state_dict():
if k not in load_keys:
load_dict[k] = model.state_dict()[k]
model.load_state_dict(load_dict, strict=(not args.lora))
model.load_state_dict(load_dict, strict=(not freeze))
if os.path.isfile(args.lora_load):
model.load_state_dict(
torch.load(args.lora_load, map_location="cpu"), strict=False
)
if args.PISSA:
init_dict = {}
rank_zero_info(f"########## Init PISSA... ##########")
for name, m in model.named_modules():
if hasattr(m, "pissa_init") and callable(getattr(m, "pissa_init")):
m.pissa_init(args.svd_niter)
init_dict[f"{name}.init_lora_A"] = m.lora_A.data
init_dict[f"{name}.init_lora_B"] = m.lora_B.data
torch.save(init_dict, f"{args.proj_dir}/init_lora.pth")
if args.quant != "none":
rank_zero_info(f"########## Quant... ##########")
for name, m in model.named_modules():
if hasattr(m, "quant") and callable(getattr(m, "quant")):
m.quant(args.quant)
if pl.__version__[0] == "2":
trainer = Trainer(
@@ -434,3 +538,73 @@ if __name__ == "__main__":
)
trainer.fit(model, data_loader)
# if args.LISA:
# args.load_model=f'rwkv-0.pth'
# model = RWKV(args)
# model.requires_grad_(False)
# select_layers = np.random.choice(range(args.n_layer), args.lisa_r, replace=False)
# for name, module in model.named_modules():
# for pname, param in module.named_parameters():
# if 'emb' in pname or 'head' in pname or '.ln' in pname or 'time' in pname :
# param.requires_grad = True
# match = re.search(r'\d+', pname)
# if match:
# number = int(match.group())
# if number in select_layers:
# param.requires_grad = True
# break
# rank_zero_info(f"########## Loading {args.load_model}... ##########")
# try:
# load_dict = torch.load(args.load_model, map_location="cpu")
# load_keys = list(load_dict.keys())
# for k in load_keys:
# if k.startswith('_forward_module.'):
# load_dict[k.replace('_forward_module.','')] = load_dict[k]
# del load_dict[k]
# except:
# rank_zero_info(f"Bad checkpoint {args.load_model}")
# if args.my_pile_stage >= 2: # try again using another checkpoint
# max_p = args.my_pile_prev_p
# if max_p == -1:
# args.load_model = f"{args.proj_dir}/rwkv-init.pth"
# else:
# args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
# args.epoch_begin = max_p + 1
# rank_zero_info(f"Trying {args.load_model}")
# load_dict = torch.load(args.load_model, map_location="cpu")
# if args.load_partial == 1:
# load_keys = load_dict.keys()
# for k in model.state_dict():
# if k not in load_keys:
# load_dict[k] = model.state_dict()[k]
# model.load_state_dict(load_dict, strict=(not args.lora))
# if pl.__version__[0]=='2':
# trainer = Trainer(accelerator=args.accelerator,strategy=args.strategy,devices=args.devices,num_nodes=args.num_nodes,precision=args.precision,
# logger=args.logger,callbacks=[train_callback(args)],max_epochs=args.max_epochs,check_val_every_n_epoch=args.check_val_every_n_epoch,num_sanity_val_steps=args.num_sanity_val_steps,
# log_every_n_steps=args.log_every_n_steps,enable_checkpointing=args.enable_checkpointing,accumulate_grad_batches=args.accumulate_grad_batches,gradient_clip_val=args.gradient_clip_val)
# else:
# trainer = Trainer.from_argparse_args(
# args,
# callbacks=[train_callback(args)],
# )
# if trainer.global_rank == 0:
# for n in model.state_dict():
# shape = model.state_dict()[n].shape
# shape = [i for i in shape if i != 1]
# if len(shape) > 1:
# print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
# else:
# print(f"{str(shape[0]).ljust(5)} {n}")
# if "deepspeed" in args.strategy:
# trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
# trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
# # must set shuffle=False, persistent_workers=False (because worker is in another thread)
# data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
# trainer.fit(model, data_loader)