diff --git a/finetune/get_layer_and_embd.py b/finetune/get_layer_and_embd.py index 04498aa..04e501a 100644 --- a/finetune/get_layer_and_embd.py +++ b/finetune/get_layer_and_embd.py @@ -32,6 +32,7 @@ cleaner_thread.start() w = torch.load(model_file, map_location="cpu") gc.collect() +vocab_size = w["emb.weight"].shape[0] n_embd = w["emb.weight"].shape[1] n_layer = 0 keys = list(w.keys()) @@ -52,6 +53,9 @@ for x in keys: version = max(6, version) if version <= expected_max_version: - print(f"--n_layer {n_layer} --n_embd {n_embd}", end="") + print( + f"v{int(version)}/train.py --vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}", + end="", + ) else: raise Exception(f"RWKV{version} is not supported") diff --git a/finetune/install-wsl-dep-and-train.sh b/finetune/install-wsl-dep-and-train.sh index 83739f2..461edb3 100644 --- a/finetune/install-wsl-dep-and-train.sh +++ b/finetune/install-wsl-dep-and-train.sh @@ -47,10 +47,10 @@ else fi echo "loading $loadModel" -modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 4) +modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 5.2) echo $modelInfo if [[ $modelInfo =~ "--n_layer" ]]; then - python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \ + python3 ./finetune/lora/$modelInfo $@ --proj_dir lora-models --data_type binidx --lora \ --lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu else echo "modelInfo is invalid" diff --git a/finetune/lora/cuda/wkv_cuda.cu b/finetune/lora/v4/cuda/wkv_cuda.cu similarity index 100% rename from finetune/lora/cuda/wkv_cuda.cu rename to finetune/lora/v4/cuda/wkv_cuda.cu diff --git a/finetune/lora/cuda/wkv_cuda_bf16.cu b/finetune/lora/v4/cuda/wkv_cuda_bf16.cu similarity index 100% rename from finetune/lora/cuda/wkv_cuda_bf16.cu rename to finetune/lora/v4/cuda/wkv_cuda_bf16.cu diff --git a/finetune/lora/cuda/wkv_op.cpp b/finetune/lora/v4/cuda/wkv_op.cpp similarity index 100% rename from finetune/lora/cuda/wkv_op.cpp rename to finetune/lora/v4/cuda/wkv_op.cpp diff --git a/finetune/lora/cuda/wkv_op_bf16.cpp b/finetune/lora/v4/cuda/wkv_op_bf16.cpp similarity index 100% rename from finetune/lora/cuda/wkv_op_bf16.cpp rename to finetune/lora/v4/cuda/wkv_op_bf16.cpp diff --git a/finetune/lora/src/__init__.py b/finetune/lora/v4/src/__init__.py similarity index 100% rename from finetune/lora/src/__init__.py rename to finetune/lora/v4/src/__init__.py diff --git a/finetune/lora/src/binidx.py b/finetune/lora/v4/src/binidx.py similarity index 98% rename from finetune/lora/src/binidx.py rename to finetune/lora/v4/src/binidx.py index 369081a..8d5b40b 100644 --- a/finetune/lora/src/binidx.py +++ b/finetune/lora/v4/src/binidx.py @@ -7,6 +7,7 @@ import struct from functools import lru_cache from itertools import accumulate + def print_rank_0(*message): pass # """If distributed is initialized print only on rank 0.""" @@ -16,12 +17,14 @@ def print_rank_0(*message): # else: # print(*message, flush=True) + def _warmup_mmap_file(path): pass # with open(path, "rb") as stream: # while stream.read(100 * 1024 * 1024): # pass + dtypes = { 1: np.uint8, 2: np.int8, @@ -33,18 +36,22 @@ dtypes = { 8: np.uint16, } + def code(dtype): for k in dtypes.keys(): if dtypes[k] == dtype: return k raise ValueError(dtype) + def index_file_path(prefix_path): return prefix_path + ".idx" + def data_file_path(prefix_path): return prefix_path + ".bin" + class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b"MMIDIDX\x00\x00" @@ -100,7 +107,7 
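The install-wsl-dep-and-train.sh hunk above now executes `python3 ./finetune/lora/$modelInfo ...`, so the string printed by get_layer_and_embd.py doubles as the relative path of the per-version trainer. A minimal sketch of that contract, with illustrative (not real) model dimensions:

# Sketch of the modelInfo string contract assumed by the shell script above.
# int(version) picks the trainer folder: 4 -> v4/train.py, 5.x -> v5/train.py.
def build_model_info(version: float, vocab_size: int, n_layer: int, n_embd: int) -> str:
    return f"v{int(version)}/train.py --vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}"

# Illustrative values only:
print(build_model_info(5.2, 65536, 32, 2560))
# -> v5/train.py --vocab_size 65536 --n_layer 32 --n_embd 2560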
@@ class MMapIndexedDataset(torch.utils.data.Dataset): self._file.close() return _Writer() - + def __init__(self, path, skip_warmup=False): with open(path, "rb") as stream: magic_test = stream.read(9) @@ -217,8 +224,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): elif isinstance(idx, slice): start, stop, step = idx.indices(len(self)) if step != 1: - raise ValueError( - "Slices into indexed_dataset must be contiguous") + raise ValueError("Slices into indexed_dataset must be contiguous") ptr = self._index._pointers[start] sizes = self._index._sizes[idx] offsets = list(accumulate(sizes)) diff --git a/finetune/lora/src/dataset.py b/finetune/lora/v4/src/dataset.py similarity index 72% rename from finetune/lora/src/dataset.py rename to finetune/lora/v4/src/dataset.py index 54e3865..88ddf6d 100644 --- a/finetune/lora/src/dataset.py +++ b/finetune/lora/v4/src/dataset.py @@ -17,9 +17,11 @@ class MyDataset(Dataset): if args.data_type == "binidx": self.vocab_size = args.vocab_size - rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)") + rank_zero_info( + f"Current vocab size = {self.vocab_size} (make sure it's correct)" + ) - if args.data_file.endswith('/'): + if args.data_file.endswith("/"): d_all = [] for p in os.listdir(args.data_file): if p.endswith(".idx"): @@ -29,33 +31,52 @@ class MyDataset(Dataset): exit(0) else: self.data = MMapIndexedDataset(args.data_file) - self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size + self.data_size = ( + len(self.data._bin_buffer) // self.data._index._dtype_size + ) rank_zero_info(f"Data has {self.data_size} tokens.") if args.my_qa_mask > 0: - self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document') - self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size + self.data_pile = MMapIndexedDataset( + "/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document" + ) + self.data_pile_size = ( + len(self.data_pile._bin_buffer) // self.data._index._dtype_size + ) if args.my_pile_stage > 0: # assert self.data_size == 332115325534 and self.vocab_size == 50277 self.samples_per_epoch = args.epoch_steps * args.real_bsz assert self.samples_per_epoch == 40320 - rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") + rank_zero_info( + f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########" + ) dataset_slot = self.data_size // args.ctx_len if args.my_pile_stage != 4: assert MaybeIsPrime(args.magic_prime) assert args.magic_prime % 3 == 2 - assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1 + assert ( + args.magic_prime / dataset_slot > 0.99 + and args.magic_prime / dataset_slot <= 1 + ) elif args.data_type == "numpy": self.data = np.load(args.data_file).astype("int") self.vocab_size = args.vocab_size - rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") + rank_zero_info( + "Current vocab size =", self.vocab_size, "(make sure it's correct)" + ) self.data_size = len(self.data) rank_zero_info(f"Data has {self.data_size} tokens.") elif args.data_type == "uint16": - self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len) + self.data = ( + np.fromfile(args.data_file, dtype=np.uint16) + .astype("int32") + .reshape(-1, args.my_sample_len) + ) self.vocab_size = args.vocab_size - rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") + rank_zero_info( + "Current vocab size =", 
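The data_size computation above divides the byte length of the .bin buffer by the token width recorded in the .idx header. A minimal usage sketch, assuming a hypothetical dataset prefix and assuming the v4 `src` package is importable (as it is when train.py runs from finetune/lora/v4):

# Sketch: counting tokens in a binidx pair the same way MyDataset does above.
from src.binidx import MMapIndexedDataset  # assumes finetune/lora/v4 is on sys.path

data = MMapIndexedDataset("data/my_corpus_text_document")  # hypothetical prefix; expects .bin and .idx
n_tokens = len(data._bin_buffer) // data._index._dtype_size  # bytes in .bin / bytes per token
print(f"Data has {n_tokens} tokens.")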
self.vocab_size, "(make sure it's correct)" + ) self.data_size = self.data.shape[0] rank_zero_info(f"Data has {self.data_size} samples.") elif args.data_type == "wds_img": @@ -86,10 +107,14 @@ class MyDataset(Dataset): for u in unique: xxObj[xx] = u xx += 1 - with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file: + with open( + f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le" + ) as vocab_file: vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) self.data_size = len(self.data) - rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.") + rank_zero_info( + f"Data has {self.data_size} tokens, {self.vocab_size} vocab size." + ) self.stoi = {ch: i for i, ch in enumerate(unique)} self.itos = {i: ch for i, ch in enumerate(unique)} @@ -104,36 +129,53 @@ class MyDataset(Dataset): # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}") if args.data_type == "wds_img": + def init_wds(self, bias=0): def identity(x): - return x + return x + import webdataset as wds import torchvision.transforms as transforms + # img_transform = transforms.Compose( # [transforms.CenterCrop(256)] # ) - img_transform = transforms.Compose([ - transforms.CenterCrop(512), - transforms.Resize((args.my_img_size)) - ]) - self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity) + img_transform = transforms.Compose( + [transforms.CenterCrop(512), transforms.Resize((args.my_img_size))] + ) + self.data_raw = ( + wds.WebDataset(args.data_file, resampled=True) + .shuffle( + 10000, + initial=1000, + rng=random.Random(epoch * 100000 + rank + bias * 1e9), + ) + .decode("torchrgb") + .to_tuple("jpg", "json", "txt") + .map_tuple(img_transform, identity, identity) + ) for pp in self.data_raw.pipeline: - if 'Resampled' in str(pp): + if "Resampled" in str(pp): pp.deterministic = True + def worker_seed(): - return rank*100000+epoch+bias*1e9 + return rank * 100000 + epoch + bias * 1e9 + pp.worker_seed = worker_seed self.data = iter(self.data_raw) # print(f"WebDataset loaded for rank {rank} epoch {epoch}") + if self.data == None: init_wds(self) trial = 0 while trial < 10: try: - dd = next(self.data) # jpg, json, txt + dd = next(self.data) # jpg, json, txt break except: - print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]') + print( + f"[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]" + ) self.error_count += 1 init_wds(self, self.error_count) trial += 1 @@ -144,7 +186,7 @@ class MyDataset(Dataset): return dd[0], dd[2] else: if args.data_type == "uint16": - i = np.random.randint(0, self.data_size-1) + i = np.random.randint(0, self.data_size - 1) dix = self.data[i] x = torch.tensor(dix[:-1], dtype=torch.long) y = torch.tensor(dix[1:], dtype=torch.long) @@ -196,7 +238,12 @@ class MyDataset(Dataset): z_sum = 0 isGood = False for i in range(3, ctx_len): - if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187: + if ( + dix[i] == 27 + and dix[i - 1] == 34 + and dix[i - 2] == 187 + and dix[i - 3] == 187 + ): isGood = True if dix[i] == 0: isGood = False @@ -206,7 +253,9 @@ class MyDataset(Dataset): if z_sum == 0: z = [1] * ctx_len i = np.random.randint(0, self.data_pile_size - req_len) - dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int) + dix = self.data_pile.get( + idx=0, offset=i, length=req_len + ).astype(int) z = 
torch.tensor(z, dtype=torch.bfloat16) x = torch.tensor(dix[:-1], dtype=torch.long) diff --git a/finetune/lora/src/model.py b/finetune/lora/v4/src/model.py similarity index 71% rename from finetune/lora/src/model.py rename to finetune/lora/v4/src/model.py index 15f8d82..6ac4dc8 100644 --- a/finetune/lora/src/model.py +++ b/finetune/lora/v4/src/model.py @@ -5,6 +5,7 @@ import functools import os, math, gc, importlib import torch + # torch._C._jit_set_profiling_executor(True) # torch._C._jit_set_profiling_mode(True) import torch.nn as nn @@ -13,7 +14,8 @@ from torch.nn import functional as F import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from pytorch_lightning.strategies import DeepSpeedStrategy -if importlib.util.find_spec('deepspeed'): + +if importlib.util.find_spec("deepspeed"): import deepspeed from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam @@ -28,9 +30,10 @@ LORA_CONFIG = { try: - print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"]) + print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"]) except: - os.environ["RWKV_MY_TESTING"] = '' + os.environ["RWKV_MY_TESTING"] = "" + def __nop(ob): return ob @@ -53,7 +56,26 @@ T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM! from torch.utils.cpp_extension import load if os.environ["RWKV_FLOAT_MODE"] == "bf16": - wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["finetune/lora/cuda/wkv_op_bf16.cpp", "finetune/lora/cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"]) + wkv_cuda = load( + name=f"wkv_{T_MAX}_bf16", + sources=[ + "finetune/lora/v4/cuda/wkv_op_bf16.cpp", + "finetune/lora/v4/cuda/wkv_cuda_bf16.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-t 4", + "-std=c++17", + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-DTmax={T_MAX}", + ], + ) + class WKV(torch.autograd.Function): @staticmethod def forward(ctx, B, T, C, w, u, k, v): @@ -66,10 +88,16 @@ if os.environ["RWKV_FLOAT_MODE"] == "bf16": u = u.contiguous() k = k.contiguous() v = v.contiguous() - y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) + y = torch.empty( + (B, T, C), + device=w.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) wkv_cuda.forward(B, T, C, w, u, k, v, y) ctx.save_for_backward(w, u, k, v, y) return y + @staticmethod def backward(ctx, gy): B = ctx.B @@ -78,16 +106,54 @@ if os.environ["RWKV_FLOAT_MODE"] == "bf16": assert T <= T_MAX assert B * C % min(C, 32) == 0 w, u, k, v, y = ctx.saved_tensors - gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) - gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) - gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) - gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) + gw = torch.empty( + (B, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) + gu = torch.empty( + (B, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) + gk = torch.empty( + (B, T, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) + gv = 
torch.empty( + (B, T, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv) gw = torch.sum(gw, dim=0) gu = torch.sum(gu, dim=0) return (None, None, None, gw, gu, gk, gv) + else: - wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["finetune/lora/cuda/wkv_op.cpp", "finetune/lora/cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"]) + wkv_cuda = load( + name=f"wkv_{T_MAX}", + sources=[ + "finetune/lora/v4/cuda/wkv_op.cpp", + "finetune/lora/v4/cuda/wkv_cuda.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-DTmax={T_MAX}", + ], + ) + class WKV(torch.autograd.Function): @staticmethod def forward(ctx, B, T, C, w, u, k, v): @@ -106,7 +172,9 @@ else: u = u.float().contiguous() k = k.float().contiguous() v = v.float().contiguous() - y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format) + y = torch.empty( + (B, T, C), device=w.device, memory_format=torch.contiguous_format + ) wkv_cuda.forward(B, T, C, w, u, k, v, y) ctx.save_for_backward(w, u, k, v, y) if "32" in os.environ["RWKV_FLOAT_MODE"]: @@ -115,6 +183,7 @@ else: return y.half() elif os.environ["RWKV_FLOAT_MODE"] == "bf16": return y.bfloat16() + @staticmethod def backward(ctx, gy): B = ctx.B @@ -123,14 +192,26 @@ else: assert T <= T_MAX assert B * C % min(C, 32) == 0 w, u, k, v, y = ctx.saved_tensors - gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format) - gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format) - gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format) - gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format) + gw = torch.empty( + (B, C), device=gy.device, memory_format=torch.contiguous_format + ) + gu = torch.empty( + (B, C), device=gy.device, memory_format=torch.contiguous_format + ) + gk = torch.empty( + (B, T, C), device=gy.device, memory_format=torch.contiguous_format + ) + gv = torch.empty( + (B, T, C), device=gy.device, memory_format=torch.contiguous_format + ) if "32" in os.environ["RWKV_FLOAT_MODE"]: - wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv) + wkv_cuda.backward( + B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv + ) else: - wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv) + wkv_cuda.backward( + B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv + ) gw = torch.sum(gw, dim=0) gu = torch.sum(gu, dim=0) if "32" in os.environ["RWKV_FLOAT_MODE"]: @@ -138,7 +219,15 @@ else: elif os.environ["RWKV_FLOAT_MODE"] == "fp16": return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) elif os.environ["RWKV_FLOAT_MODE"] == "bf16": - return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) + return ( + None, + None, + None, + gw.bfloat16(), + gu.bfloat16(), + gk.bfloat16(), + gv.bfloat16(), + ) def RUN_CUDA(B, T, C, w, u, k, v): @@ -151,15 +240,17 @@ def RUN_CUDA(B, T, C, w, u, k, v): class LoraLinear(nn.Module): - def __init__(self, in_features: int, out_features: int, bias: bool): super().__init__() self.weight = nn.Parameter(torch.empty((out_features, in_features))) assert bias == False, "Biased LoraLinear not 
supported" - r, alpha, dropout = LORA_CONFIG["r"], LORA_CONFIG[ - "alpha"], LORA_CONFIG["dropout"] + r, alpha, dropout = ( + LORA_CONFIG["r"], + LORA_CONFIG["alpha"], + LORA_CONFIG["dropout"], + ) self.lora_A = nn.Parameter(torch.empty(r, in_features)) self.lora_B = nn.Parameter(torch.empty(out_features, r)) self.lora_dropout = nn.Dropout(dropout) @@ -170,9 +261,9 @@ class LoraLinear(nn.Module): nn.init.zeros_(self.lora_B) def forward(self, x): - return ( - F.linear(x, self.weight) + self.scaling * - F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B)) + return F.linear(x, self.weight) + self.scaling * F.linear( + F.linear(self.lora_dropout(x), self.lora_A), self.lora_B + ) @functools.wraps(LoraLinear) @@ -214,17 +305,23 @@ class RWKV_TimeMix(MyModule): # fancy time_decay decay_speed = torch.ones(args.dim_att) for h in range(args.dim_att): - decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1) + decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** ( + 0.7 + 1.3 * ratio_0_to_1 + ) self.time_decay = nn.Parameter(decay_speed) # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) # fancy time_first zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5 - self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag) + self.time_first = nn.Parameter( + torch.ones(args.dim_att) * math.log(0.3) + zigzag + ) # fancy time_mix self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) - self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_v = nn.Parameter( + torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + ) self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0)) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) @@ -235,8 +332,10 @@ class RWKV_TimeMix(MyModule): self.output = nn.Linear(args.dim_att, args.n_embd, bias=False) - if 'a' in os.environ["RWKV_MY_TESTING"]: - self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + if "a" in os.environ["RWKV_MY_TESTING"]: + self.register_buffer( + "att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)) + ) d_qkv = args.n_embd // 16 self.qq = nn.Linear(args.n_embd, d_qkv, bias=False) self.kk = nn.Linear(args.n_embd, d_qkv, bias=False) @@ -245,12 +344,17 @@ class RWKV_TimeMix(MyModule): with torch.no_grad(): self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) - self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_vv = nn.Parameter( + torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + ) + + if "a" not in os.environ["RWKV_MY_TESTING"]: - if 'a' not in os.environ["RWKV_MY_TESTING"]: @MyFunction def jit_func(self, x): - xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr + xx = self.time_shift( + x + ) # Mix x with the previous timestep to produce xk, xv, xr xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) @@ -263,21 +367,26 @@ class RWKV_TimeMix(MyModule): def forward(self, x): B, T, C = x.size() # x = (Batch,Time,Channel) sr, k, v = self.jit_func(x) - rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v) + rwkv = sr * RUN_CUDA( + B, T, self.args.dim_att, self.time_decay, self.time_first, k, 
v + ) return self.output(rwkv) - if 'a' in os.environ["RWKV_MY_TESTING"]: + if "a" in os.environ["RWKV_MY_TESTING"]: + @MyFunction def QKV(self, q, k, v): att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.att_mask == 0, float('-inf')) - att = F.softmax(att, dim = -1) + att = att.masked_fill(self.att_mask == 0, float("-inf")) + att = F.softmax(att, dim=-1) x = att @ v return x @MyFunction def jit_funcQKV(self, x): - xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr + xx = self.time_shift( + x + ) # Mix x with the previous timestep to produce xk, xv, xr xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) @@ -296,12 +405,16 @@ class RWKV_TimeMix(MyModule): def forward(self, x): B, T, C = x.size() # x = (Batch,Time,Channel) sr, k, v, qq, kk, vv = self.jit_funcQKV(x) - rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v) + rwkv = sr * RUN_CUDA( + B, T, self.args.dim_att, self.time_decay, self.time_first, k, v + ) rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv)) return rwkv + ######################################################################################################## + class RWKV_ChannelMix(MyModule): def __init__(self, args, layer_id): super().__init__() @@ -331,6 +444,7 @@ class RWKV_ChannelMix(MyModule): kv = self.value(k) return torch.sigmoid(self.receptance(xr)) * kv + class MishGLU(MyModule): def __init__(self, args, layer_id): super().__init__() @@ -360,6 +474,7 @@ class MishGLU(MyModule): b = self.bb(xb) return self.value(a * F.mish(b)) + ######################################################################################################## # The RWKV Model with our blocks ######################################################################################################## @@ -377,15 +492,19 @@ class Block(nn.Module): if self.layer_id == 0: self.ln0 = nn.LayerNorm(args.n_embd) if args.my_pos_emb > 0: - self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd))) - self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd))) + self.pos_emb_x = nn.Parameter( + torch.zeros((1, args.my_pos_emb, args.n_embd)) + ) + self.pos_emb_y = nn.Parameter( + torch.zeros((args.my_pos_emb, 1, args.n_embd)) + ) if self.layer_id == 0 and self.args.pre_ffn > 0: self.ffnPre = RWKV_ChannelMix(args, 0) else: self.att = RWKV_TimeMix(args, layer_id) - if 'g' in os.environ["RWKV_MY_TESTING"]: + if "g" in os.environ["RWKV_MY_TESTING"]: self.ffn = MishGLU(args, layer_id) else: self.ffn = RWKV_ChannelMix(args, layer_id) @@ -395,7 +514,9 @@ class Block(nn.Module): self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False) - self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + self.register_buffer( + "tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)) + ) def forward(self, x, x_emb=None): args = self.args @@ -403,7 +524,7 @@ class Block(nn.Module): if self.layer_id == 0: x = self.ln0(x) if args.my_pos_emb > 0: - pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:] + pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :] x = x + pos_emb if self.layer_id == 0 and args.pre_ffn > 0: @@ -443,13 +564,13 @@ class RWKV(pl.LightningModule): def 
__init__(self, args): super().__init__() self.args = args - if not hasattr(args, 'dim_att'): + if not hasattr(args, "dim_att"): args.dim_att = args.n_embd - if not hasattr(args, 'dim_ffn'): + if not hasattr(args, "dim_ffn"): args.dim_ffn = args.n_embd * 4 - if not hasattr(args, 'tiny_att_layer'): + if not hasattr(args, "tiny_att_layer"): args.tiny_att_layer = -1 - if not hasattr(args, 'tiny_att_dim'): + if not hasattr(args, "tiny_att_dim"): args.tiny_att_dim = -1 self.emb = nn.Embedding(args.vocab_size, args.n_embd) @@ -462,7 +583,9 @@ class RWKV(pl.LightningModule): if args.head_qk > 0: self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False) self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False) - self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + self.register_buffer( + "copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)) + ) def configure_optimizers(self): args = self.args @@ -494,19 +617,46 @@ class RWKV(pl.LightningModule): param_dict = {n: p for n, p in self.named_parameters()} if args.my_pile_stage == 2: optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, + { + "params": [param_dict[n] for n in lr_1x], + "weight_decay": 0.0, + "my_lr_scale": 1.0, + }, + { + "params": [param_dict[n] for n in lr_2x], + "weight_decay": 0.0, + "my_lr_scale": 5.0, + }, # test: 2e-3 / args.lr_init}, + { + "params": [param_dict[n] for n in lr_3x], + "weight_decay": 0.0, + "my_lr_scale": 5.0, + }, # test: 3e-3 / args.lr_init}, ] else: optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, + { + "params": [param_dict[n] for n in lr_1x], + "weight_decay": 0.0, + "my_lr_scale": 1.0, + }, + { + "params": [param_dict[n] for n in lr_2x], + "weight_decay": 0.0, + "my_lr_scale": 2.0, + }, + { + "params": [param_dict[n] for n in lr_3x], + "weight_decay": 0.0, + "my_lr_scale": 3.0, + }, ] else: optim_groups = [ - {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, + { + "params": [p for n, p in self.named_parameters()], + "weight_decay": 0.0, + }, ] for g in optim_groups: @@ -514,8 +664,26 @@ class RWKV(pl.LightningModule): optim_groups = [g for g in optim_groups if len(g["params"]) > 0] if self.deepspeed_offload: - return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) - return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + return DeepSpeedCPUAdam( + optim_groups, + lr=self.args.lr_init, + betas=self.args.betas, + eps=self.args.adam_eps, + bias_correction=True, + adamw_mode=False, + weight_decay=0, + amsgrad=False, + ) + return FusedAdam( + optim_groups, + lr=self.args.lr_init, + betas=self.args.betas, + eps=self.args.adam_eps, + bias_correction=True, + adam_w_mode=False, + weight_decay=0, + amsgrad=False, + ) # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, 
eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) @property @@ -589,10 +757,14 @@ class RWKV(pl.LightningModule): logits = self(idx) if sum_mask == mask.shape[0]: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1) + ) # print('rank', self.global_rank, 'loss', loss.item()) else: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none') + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none" + ) # loss_raw = loss loss = torch.sum(loss * mask) / sum_mask @@ -632,7 +804,14 @@ class RWKV(pl.LightningModule): gain = 1.0 scale = 1.0 - if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n: + if ( + "ln_" in n + or ".ln" in n + or "time_" in n + or "_mask" in n + or "pos_emb" in n + or ".mask." in n + ): m[n] = p else: if n == "emb.weight": @@ -640,7 +819,19 @@ class RWKV(pl.LightningModule): else: if shape[0] > shape[1]: gain = math.sqrt(shape[0] / shape[1]) - for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']: + for kk in [ + ".att.key.", + ".att.receptance.", + ".att.output.", + ".att.key.", + ".ffn.value.", + ".ffn.receptance.", + ".ffnPre.value.", + ".ffnPre.receptance.", + "head_q.", + ".oo.", + ".rr.", + ]: if kk in n: scale = 0 if n == "head.weight": @@ -650,7 +841,9 @@ class RWKV(pl.LightningModule): if "head_q." in n: scale = 0 - print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}") + print( + f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}" + ) if self.args.accelerator.upper() == "GPU": m[n] = torch.empty((shape[0], shape[1]), device="cuda") diff --git a/finetune/lora/src/trainer.py b/finetune/lora/v4/src/trainer.py similarity index 78% rename from finetune/lora/src/trainer.py rename to finetune/lora/v4/src/trainer.py index ab65776..d78207c 100644 --- a/finetune/lora/src/trainer.py +++ b/finetune/lora/v4/src/trainer.py @@ -5,15 +5,17 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from .model import LORA_CONFIG + def my_save(dd, ff): - if '14b-run1' not in ff: + if "14b-run1" not in ff: torch.save(dd, ff) else: - fn = ff.split('/')[-1] - fff = '/dev/shm/' + fn + fn = ff.split("/")[-1] + fff = "/dev/shm/" + fn torch.save(dd, fff) subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True) + class train_callback(pl.Callback): def __init__(self, args): super().__init__() @@ -38,7 +40,9 @@ class train_callback(pl.Callback): if args.lr_final == 0 or args.lr_init == 0: # linear decay lr = args.lr_init + (args.lr_final - args.lr_init) * progress else: # exp decay - lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1)) + lr = args.lr_init * math.exp( + math.log(args.lr_final / args.lr_init) * pow(progress, 1) + ) if trainer.global_step < w_step: lr = lr * (0.2 + 0.8 * trainer.global_step / w_step) @@ -60,7 +64,9 @@ class train_callback(pl.Callback): trainer.my_loss_sum = 0 trainer.my_loss_count = 0 trainer.my_log = open(args.proj_dir + "/train_log.txt", "a") - trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n") + trainer.my_log.write( + f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n" + ) try: 
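The reflowed learning-rate code above interpolates exponentially from lr_init to lr_final over training progress, with a short linear warmup that scales the result from 0.2x up to 1.0x over the first w_step steps. A standalone sketch with illustrative values:

# Standalone sketch of the schedule above (argument defaults are illustrative).
import math

def lr_at(progress, global_step, lr_init=6e-4, lr_final=1e-5, w_step=50):
    if lr_final == 0 or lr_init == 0:            # linear decay branch
        lr = lr_init + (lr_final - lr_init) * progress
    else:                                        # exponential decay branch
        lr = lr_init * math.exp(math.log(lr_final / lr_init) * progress)
    if global_step < w_step:                     # warmup ramp
        lr = lr * (0.2 + 0.8 * global_step / w_step)
    return lr

print(lr_at(0.0, 0))      # 1.2e-4: start of warmup
print(lr_at(0.5, 1000))   # ~7.7e-5: geometric midpoint of lr_init and lr_final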
print(f"\n{trainer.strategy.config}\n") trainer.my_log.write(f"{trainer.strategy.config}\n") @@ -70,6 +76,7 @@ class train_callback(pl.Callback): if len(args.wandb) > 0: print("Login to wandb...") import wandb + wandb.init( project=args.wandb, name=args.run_name + " " + args.my_timestamp, @@ -102,20 +109,26 @@ class train_callback(pl.Callback): # self.log("s", real_step, prog_bar=True, on_step=True) if len(args.wandb) > 0: - lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9} + lll = { + "loss": trainer.my_loss, + "lr": trainer.my_lr, + "Gtokens": real_step * token_per_step / 1e9, + } if kt_s > 0: lll["kt/s"] = kt_s trainer.my_wandb.log(lll, step=int(real_step)) if args.magic_prime > 0: expand_factor = 2 if args.my_qa_mask > 0 else 1 - if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1: + if ( + int(real_step) + == int(args.magic_prime * expand_factor // args.real_bsz) - 1 + ): to_save_dict = pl_module.state_dict() my_save( to_save_dict, f"{args.proj_dir}/rwkv-final.pth", ) - def on_train_epoch_start(self, trainer, pl_module): args = self.args dataset = trainer.train_dataloader.dataset.datasets @@ -128,24 +141,28 @@ class train_callback(pl.Callback): def on_train_epoch_end(self, trainer, pl_module): args = self.args if trainer.is_global_zero: # logging & save state_dict - if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1: - if args.data_type == 'wds_img': + if ( + args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0 + ) or trainer.current_epoch == args.epoch_count - 1: + if args.data_type == "wds_img": raw_dict = pl_module.state_dict() to_save_dict = {} for k in raw_dict: - if k.startswith('encoder.') or k.startswith('decoder.'): + if k.startswith("encoder.") or k.startswith("decoder."): to_save_dict[k] = raw_dict[k] else: to_save_dict = pl_module.state_dict() if args.lora: - enable_time_finetune = 'time' in LORA_CONFIG["parts"] - enable_ln_finetune = 'ln' in LORA_CONFIG["parts"] + enable_time_finetune = "time" in LORA_CONFIG["parts"] + enable_ln_finetune = "ln" in LORA_CONFIG["parts"] lora_dict = {} for name, state in to_save_dict.items(): - if ('.lora_' in name - or (enable_time_finetune and '.time_' in name) - or (enable_ln_finetune and '.ln' in name)): + if ( + ".lora_" in name + or (enable_time_finetune and ".time_" in name) + or (enable_ln_finetune and ".ln" in name) + ): lora_dict[name] = state to_save_dict = lora_dict @@ -155,8 +172,10 @@ class train_callback(pl.Callback): f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth", ) except Exception as e: - print('Error\n\n', e, '\n\n') - trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n") + print("Error\n\n", e, "\n\n") + trainer.my_log.write( + f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n" + ) trainer.my_log.flush() trainer.my_loss_sum = 0 @@ -178,22 +197,22 @@ def generate_init_weight(model, init_weight_name): mm[k] = src.reshape(mm[k].shape) except: tmp = mm[k].squeeze().clone() - print(k, src.shape, '-->', mm[k].shape) + print(k, src.shape, "-->", mm[k].shape) ss = src.shape[0] dd = tmp.shape[0] for i in range(dd): pos = i / dd * ss if pos >= ss - 1: - tmp[i] = 
src[ss-1] + tmp[i] = src[ss - 1] else: p0 = int(math.floor(pos)) ii = pos - p0 - tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii) + tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii) mm[k] = tmp.reshape(mm[k].shape) sss = src.squeeze().float().cpu().numpy() - print(sss[:10], '...', sss[-10:]) + print(sss[:10], "...", sss[-10:]) mmm = mm[k].squeeze().float().cpu().numpy() - print(mmm[:10], '...', mmm[-10:]) + print(mmm[:10], "...", mmm[-10:]) print(f"Save to {init_weight_name}...") torch.save(mm, init_weight_name) diff --git a/finetune/lora/src/utils.py b/finetune/lora/v4/src/utils.py similarity index 84% rename from finetune/lora/src/utils.py rename to finetune/lora/v4/src/utils.py index ea25990..87da098 100644 --- a/finetune/lora/src/utils.py +++ b/finetune/lora/v4/src/utils.py @@ -6,6 +6,7 @@ from torch.nn import functional as F time_slot = {} time_ref = time.time_ns() + def record_time(name): if name not in time_slot: time_slot[name] = 1e20 @@ -13,20 +14,23 @@ def record_time(name): if tt < time_slot[name]: time_slot[name] = tt -class TOKENIZER(): - def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): - if 'list' in str(type(WORD_NAME)): + +class TOKENIZER: + def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"): + if "list" in str(type(WORD_NAME)): self.charMode = False if WORD_NAME[0] == WORD_NAME[1]: from transformers import PreTrainedTokenizerFast + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0]) else: from transformers import GPT2TokenizerFast + self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) self.vocab_size = len(self.tokenizer) else: self.charMode = True - with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: + with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file: self.word_table = json.load(result_file) self.vocab_size = len(self.word_table) @@ -37,23 +41,25 @@ class TOKENIZER(): self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] def refine_context(self, context): - context = context.strip().split('\n') + context = context.strip().split("\n") for c in range(len(context)): - context[c] = context[c].strip().strip('\u3000').strip('\r') - context = list(filter(lambda c: c != '', context)) - context = '\n' + ('\n'.join(context)).strip() - if context == '': - context = '\n' + context[c] = context[c].strip().strip("\u3000").strip("\r") + context = list(filter(lambda c: c != "", context)) + context = "\n" + ("\n".join(context)).strip() + if context == "": + context = "\n" return context - def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): + def sample_logits( + self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None + ): # out[self.UNKNOWN_CHAR] = -float('Inf') lastChar = int(x[-1]) probs = F.softmax(out, dim=-1) if self.charMode: - if self.itos[lastChar] == '\n': + if self.itos[lastChar] == "\n": top_p = top_p_newline else: top_p = top_p_usual @@ -81,6 +87,7 @@ class TOKENIZER(): out = torch.multinomial(probs, num_samples=1)[0] return out + def MaybeIsPrime(number): if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number): return True @@ -121,7 +128,9 @@ def MillerRabinPrimalityTest(number): if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1): iterationNumber = 1 - while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1): + while (iterationNumber <= timesTwoDividNumber - 1) and ( + randomNumberWithPower != number - 1 + ): randomNumberWithPower = pow(randomNumberWithPower, 2, number) 
iterationNumber = iterationNumber + 1 if randomNumberWithPower != (number - 1): diff --git a/finetune/lora/train.py b/finetune/lora/v4/train.py similarity index 98% rename from finetune/lora/train.py rename to finetune/lora/v4/train.py index 78f2646..108d3bb 100644 --- a/finetune/lora/train.py +++ b/finetune/lora/v4/train.py @@ -184,7 +184,7 @@ if __name__ == "__main__": args.num_sanity_val_steps = 0 args.check_val_every_n_epoch = int(1e20) args.log_every_n_steps = int(1e20) - args.max_epochs = args.epoch_count # continue forever + args.max_epochs = args.epoch_count # -1 continue forever args.betas = (args.beta1, args.beta2) args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz os.environ["RWKV_T_MAX"] = str(args.ctx_len) @@ -373,7 +373,7 @@ if __name__ == "__main__": for param in module.parameters(): param.requires_grad = True elif enable_time_finetune and any( - n.startswith("time") for n, _ in module.named_parameters() + n.startswith("time") for n, _ in module.named_parameters() ): for pname, param in module.named_parameters(): if pname.startswith("time"): @@ -381,7 +381,7 @@ if __name__ == "__main__": param.requires_grad = True if ( - len(args.load_model) == 0 or args.my_pile_stage == 1 + len(args.load_model) == 0 or args.my_pile_stage == 1 ): # shall we build the initial weights? init_weight_name = f"{args.proj_dir}/rwkv-init.pth" generate_init_weight(model, init_weight_name) # save initial weights @@ -423,8 +423,8 @@ if __name__ == "__main__": ) if ( - args.lr_init > 1e-4 - or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8 + args.lr_init > 1e-4 + or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8 ): if "I_KNOW_WHAT_IM_DOING" in os.environ: if trainer.global_rank == 0: @@ -459,10 +459,10 @@ if __name__ == "__main__": if "deepspeed" in args.strategy: trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = ( - args.ds_bucket_mb * 1000 * 1000 + args.ds_bucket_mb * 1000 * 1000 ) trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = ( - args.ds_bucket_mb * 1000 * 1000 + args.ds_bucket_mb * 1000 * 1000 ) # must set shuffle=False, persistent_workers=False (because worker is in another thread) diff --git a/finetune/lora/v5/cuda/wkv5_cuda.cu b/finetune/lora/v5/cuda/wkv5_cuda.cu new file mode 100644 index 0000000..3e6b859 --- /dev/null +++ b/finetune/lora/v5/cuda/wkv5_cuda.cu @@ -0,0 +1,202 @@ +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +template +__global__ void kernel_forward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _w += h*_N_; + _u += h*_N_; + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + float state[_N_] = {0}; + + __syncthreads(); + w[i] = _w[i]; + u[i] = float(_u[i]); + __syncthreads(); + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + 
x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } +} + +template +__global__ void kernel_backward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, + F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _w += h*_N_; + _u += h*_N_; + __w += h*_N_; + + __shared__ float w_[_N_], u_[_N_]; + __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_]; + __syncthreads(); + w_[i] = _w[i]; + u_[i] = float(_u[i]); + __syncthreads(); + + const float w = w_[i]; + const float ww = __w[i]; + const float u = u_[i]; + + float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; + + float gw = 0, gu = 0; + const int t000 = b*T*C + h*_N_ + i; + const int t111 = (b+1)*T*C + h*_N_ + i; + const int t222 = t111 - 2*C; + + for (int t = t000; t < t111; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float k = float(_k[t]); + float gr = 0, gu_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = state[j]; + float x = k * v[j]; + + gr += (u * x + s) * gy[j]; + gu_ += x * gy[j]; + s = s * w + x; + } + _gr[t] = F(gr); + gu += float(_r[t]) * gu_; + } + _gu[b*C + h*_N_ + i] = F(gu); + + for (int t = t000; t < t222; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t + 2*C]); + __syncthreads(); + + const float k = float(_k[t]); + float gw_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + float& s2 = sbbbb[j]; + float x = k * v[j]; + + float tmp = w * (x + s); + s = tmp; + s2 = tmp + w * s2; + gw_ += s2 * gy[j]; + } + gw += float(_r[t + 2*C]) * gw_; + } + _gw[b*C + h*_N_ + i] = F(ww * gw); + + for (int t = t111 - C; t >= t000; t -= C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float rr = float(_r[t]); + float gk = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = rr * gy[j]; + + gk += (u * x + s) * v[j]; + s = x + s * w; + } + _gk[t] = F(gk); + } + + for (int t = t111 - C; t >= t000; t -= C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float gyy = float(_gy[t]); + float gv = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = sdddd[j]; + float x = gyy * r[j]; + + gv += (u_[j] * x + s) * k[j]; + s = x + s * w_[j]; + } + _gv[t] = F(gv); + } +} + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); +} + +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 
+} diff --git a/finetune/lora/v5/cuda/wkv5_op.cpp b/finetune/lora/v5/cuda/wkv5_op.cpp new file mode 100644 index 0000000..4c9ece1 --- /dev/null +++ b/finetune/lora/v5/cuda/wkv5_op.cpp @@ -0,0 +1,22 @@ +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); + +void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { + cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv5 forward"); + m.def("backward", &backward, "wkv5 backward"); +} + +TORCH_LIBRARY(wkv5, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/finetune/lora/v5/src/__init__.py b/finetune/lora/v5/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/finetune/lora/v5/src/binidx.py b/finetune/lora/v5/src/binidx.py new file mode 100644 index 0000000..c2d60a1 --- /dev/null +++ b/finetune/lora/v5/src/binidx.py @@ -0,0 +1,303 @@ +from lib2to3.pgen2 import token +import os +import torch +import numpy as np +import shutil +import struct +from functools import lru_cache +from itertools import accumulate + + +def print_rank_0(*message): + pass + # """If distributed is initialized print only on rank 0.""" + # if torch.distributed.is_initialized(): + # if torch.distributed.get_rank() == 0: + # print(*message, flush=True) + # else: + # print(*message, flush=True) + + +def _warmup_mmap_file(path): + pass + # with open(path, "rb") as stream: + # while stream.read(100 * 1024 * 1024): + # pass + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: float, + 7: np.double, + 8: np.uint16, +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + ".idx" + + +def data_file_path(prefix_path): + return prefix_path + ".bin" + + +class MMapIndexedDataset(torch.utils.data.Dataset): + class Index(object): + _HDR_MAGIC = b"MMIDIDX\x00\x00" + + @classmethod + def writer(cls, path, dtype): + class _Writer(object): + def __enter__(self): + self._file = open(path, "wb") + + # Write Magic string so we can check the file format then opening it again. 
+ self._file.write(cls._HDR_MAGIC) + # Write version number + # Little endian unsigned 64 Bit integer + self._file.write(struct.pack(" 0: + # self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document') + self.data_pile = MMapIndexedDataset( + "/fsx/pile_deduped/pile_0.87_deduped_text_document" + ) + self.data_pile_size = ( + len(self.data_pile._bin_buffer) // self.data._index._dtype_size + ) + else: + self.data_pile = None + self.data_pile_size = 0 + + if args.my_pile_stage > 0: + # assert self.data_size == 332115325534 and self.vocab_size == 50277 + self.samples_per_epoch = args.epoch_steps * args.real_bsz + assert self.samples_per_epoch == 40320 + rank_zero_info( + f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########" + ) + dataset_slot = self.data_size // args.ctx_len + if args.my_pile_stage != 4: + assert MaybeIsPrime(args.magic_prime) + assert args.magic_prime % 3 == 2 + assert ( + args.magic_prime / dataset_slot > 0.99 + and args.magic_prime / dataset_slot <= 1 + ) + elif args.data_type == "numpy": + self.data = np.load(args.data_file).astype("int") + self.vocab_size = args.vocab_size + rank_zero_info( + f"Current vocab size = {self.vocab_size} (make sure it's correct)" + ) + self.data_size = len(self.data) + rank_zero_info(f"Data has {self.data_size} tokens.") + elif args.data_type == "uint16": + self.data = ( + np.fromfile(args.data_file, dtype=np.uint16) + .astype("int32") + .reshape(-1, args.my_sample_len) + ) + self.vocab_size = args.vocab_size + rank_zero_info( + f"Current vocab size = {self.vocab_size} (make sure it's correct)" + ) + self.data_size = self.data.shape[0] + rank_zero_info(f"Data has {self.data_size} samples.") + else: + if args.data_type == "dummy": + rank_zero_info("Building dummy data...") + self.data = "" + for i in range(100000): + aa = (i) % 10000 + bb = (i * i) % 10000 + cc = aa + bb + self.data += f".{aa}+{bb}={cc}." + else: + self.data = open(args.data_file, "r", encoding=args.data_type).read() + rank_zero_info("Building token list...") + unique = sorted(list(set(self.data))) + self.vocab_size = len(unique) + # rank_zero_info() + # for u in unique: + # print(u, end=' ') + # rank_zero_info('\n\n') + xx = 0 + xxObj = {} + for u in unique: + xxObj[xx] = u + xx += 1 + with open( + f"{args.proj_dir}/vocab.json", "w", encoding="utf-8" + ) as vocab_file: + vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) + self.data_size = len(self.data) + rank_zero_info( + f"Data has {self.data_size} tokens, {self.vocab_size} vocab size." 
+ ) + self.stoi = {ch: i for i, ch in enumerate(unique)} + self.itos = {i: ch for i, ch in enumerate(unique)} + + def __len__(self): + return self.args.epoch_steps * self.args.micro_bsz + + def __getitem__(self, idx): + args = self.args + rank = self.global_rank + epoch = self.real_epoch + world_size = self.world_size + # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}") + + if args.data_type == "uint16": + i = np.random.randint(0, self.data_size - 1) + dix = self.data[i] + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) + else: + ctx_len = args.ctx_len + req_len = ctx_len + 1 + magic_prime = args.magic_prime + data = self.data + + if args.my_pile_stage > 0: + ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank + + if args.my_qa_mask > 0: + ii_orig = ii + if ii % 2 == 0: + ii = -1 + data = self.data_pile + else: + ii = ii // 2 + if data == self.data_pile: + i = np.random.randint(0, self.data_pile_size - req_len) + else: + if args.my_pile_stage == 4 or ii < args.my_random_steps: + # cheat: pick a random spot in dataset + if args.my_pile_version == 1: + i = np.random.randint(0, self.data_size - req_len) + else: + i = np.random.randint(0, self.data_size) + else: + ii = ii - args.my_random_steps + factor = (math.sqrt(5) - 1) / 2 + factor = int(magic_prime * factor) + i = ((factor * ii * ii * ii) % magic_prime) * ctx_len + i = i + args.my_pile_shift + # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}") + else: + # cheat: pick a random spot in dataset + i = np.random.randint(0, self.data_size - req_len) + + if args.data_type == "binidx": + if args.my_pile_version == 1: + dix = data.get(idx=0, offset=i, length=req_len).astype(int) + else: + # self.data : cutoff, chunk_count, data + for j in range(len(data)): + if i < data[j][0]: + ii = i + i = (i - (data[j - 1][0] if j > 0 else 0)) % data[j][1] + dix = ( + data[j][2] + .get(idx=0, offset=i, length=req_len) + .astype(int) + ) + # print(ii, j, i) + break + elif args.data_type == "numpy": + dix = data[i : i + req_len] + else: + dix = [self.stoi[s] for s in data[i : i + req_len]] + + if args.my_qa_mask == 1: + if data == self.data_pile: + z = [1] * ctx_len + else: + z = [0] * ctx_len + z_sum = 0 + isGood = False + for i in range(3, ctx_len): + if ( + dix[i] == 27 + and dix[i - 1] == 34 + and dix[i - 2] == 187 + and dix[i - 3] == 187 + ): + isGood = True + if dix[i] == 0: + isGood = False + if isGood: + z[i] = 1 + z_sum += 1 + if z_sum == 0: + z = [1] * ctx_len + i = np.random.randint(0, self.data_pile_size - req_len) + dix = self.data_pile.get( + idx=0, offset=i, length=req_len + ).astype(int) + z = torch.tensor(z, dtype=torch.bfloat16) + + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) + + # if ii_orig < 50: + # # if rank == 1: + # print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:]) + # else: + # exit(0) + + if args.my_qa_mask == 1: + return x, y, z + + return x, y diff --git a/finetune/lora/v5/src/model.py b/finetune/lora/v5/src/model.py new file mode 100644 index 0000000..d961293 --- /dev/null +++ b/finetune/lora/v5/src/model.py @@ -0,0 +1,819 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## +import functools +import os, math, gc, importlib 
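The binidx sampling in __getitem__ above relies on a number-theoretic shuffle: since magic_prime is a prime with magic_prime % 3 == 2, cubing is a bijection modulo magic_prime, so ((factor * ii**3) % magic_prime) * ctx_len visits every ctx_len-aligned slot exactly once per cycle. A small check with a toy prime (the real magic_prime is chosen just below data_size // ctx_len):

# Toy-sized check of the magic_prime permutation used in __getitem__ above.
import math

magic_prime = 17                                    # toy prime; 17 % 3 == 2
factor = int(magic_prime * (math.sqrt(5) - 1) / 2)  # same golden-ratio factor as above

slots = [(factor * ii * ii * ii) % magic_prime for ii in range(magic_prime)]
assert sorted(slots) == list(range(magic_prime))    # each slot index appears exactly once
print(slots)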
+import torch + +# torch._C._jit_set_profiling_executor(True) +# torch._C._jit_set_profiling_mode(True) +import torch.nn as nn +from torch.utils.checkpoint import checkpoint as torch_checkpoint +from torch.nn import functional as F +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_info, rank_zero_only +from pytorch_lightning.strategies import DeepSpeedStrategy + +if importlib.util.find_spec("deepspeed"): + import deepspeed + from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam + + +# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam + +# lora-config +LORA_CONFIG = { + "r": 0, + "alpha": 0, + "dropout": 0, + "parts": {"att", "ln", "time"}, +} + +try: + print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"]) +except: + os.environ["RWKV_MY_TESTING"] = "" + + +def __nop(ob): + return ob + + +MyModule = nn.Module +MyFunction = __nop +if os.environ["RWKV_JIT_ON"] == "1": + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method + + +######################################################################################################## +# CUDA Kernel +######################################################################################################## + +from torch.utils.cpp_extension import load + +HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"]) +wkv5_cuda = load( + name="wkv5", + sources=[ + "finetune/lora/v5/cuda/wkv5_op.cpp", + f"finetune/lora/v5/cuda/wkv5_cuda.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + ], +) + + +class WKV_5(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, H, r, k, v, w, u): + with torch.no_grad(): + assert r.dtype == torch.bfloat16 + assert k.dtype == torch.bfloat16 + assert v.dtype == torch.bfloat16 + assert w.dtype == torch.bfloat16 + assert u.dtype == torch.bfloat16 + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + ew = (-torch.exp(w.float())).contiguous() + eew = (torch.exp(ew)).contiguous() + ctx.save_for_backward(r, k, v, eew, ew, u) + y = torch.empty( + (B, T, C), + device=r.device, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y) + return y + + @staticmethod + def backward(ctx, gy): + with torch.no_grad(): + assert gy.dtype == torch.bfloat16 + B = ctx.B + T = ctx.T + C = ctx.C + H = ctx.H + assert gy.is_contiguous() + r, k, v, eew, ew, u = ctx.saved_tensors + gr = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + gk = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + gv = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + gw = torch.empty( + (B, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + gu = torch.empty( + (B, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + wkv5_cuda.backward(B, T, C, H, r, k, v, 
eew, ew, u, gy, gr, gk, gv, gw, gu) + gw = torch.sum(gw, 0).view(H, C // H) + gu = torch.sum(gu, 0).view(H, C // H) + return (None, None, None, None, gr, gk, gv, gw, gu) + + +def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u): + return WKV_5.apply(B, T, C, H, r, k, v, w, u) + + +################################################################# +class LoraLinear(nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool): + super().__init__() + + self.weight = nn.Parameter(torch.empty((out_features, in_features))) + assert bias == False, "Biased LoraLinear not supported" + + r, alpha, dropout = ( + LORA_CONFIG["r"], + LORA_CONFIG["alpha"], + LORA_CONFIG["dropout"], + ) + self.lora_A = nn.Parameter(torch.empty(r, in_features)) + self.lora_B = nn.Parameter(torch.empty(out_features, r)) + self.lora_dropout = nn.Dropout(dropout) + self.scaling = alpha / r + + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def forward(self, x): + return F.linear(x, self.weight) + self.scaling * F.linear( + F.linear(self.lora_dropout(x), self.lora_A), self.lora_B + ) + + +@functools.wraps(LoraLinear) +def make_linear_att(*args, **kwargs): + if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0: + return LoraLinear(*args, **kwargs) + else: + return nn.Linear(*args, **kwargs) + + +@functools.wraps(LoraLinear) +def make_linear_ffn(*args, **kwargs): + if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0: + return LoraLinear(*args, **kwargs) + else: + return nn.Linear(*args, **kwargs) + + +######################################################################################################## + + +class RWKV_TimeMix_RWKV5(MyModule): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + + self.head_size = args.head_size_a + assert HEAD_SIZE == self.head_size # change HEAD_SIZE to match args.head_size_a + self.n_head = args.dim_att // self.head_size + assert args.dim_att % self.n_head == 0 + self.head_size_divisor = args.head_size_divisor + + with torch.no_grad(): + ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 + ddd = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + ddd[0, 0, i] = i / args.n_embd + + # fancy time_mix + self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) + self.time_mix_v = nn.Parameter( + torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + ) + self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0)) + self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0)) + + # fancy time_decay + decay_speed = torch.ones(args.dim_att) + for n in range(args.dim_att): + decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** ( + 0.7 + 1.3 * ratio_0_to_1 + ) + self.time_decay = nn.Parameter( + decay_speed.reshape(self.n_head, self.head_size) + ) + # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) + + tmp = torch.zeros(args.dim_att) + for n in range(args.dim_att): + zigzag = ((n + 1) % 3 - 1) * 0.1 + tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag + + self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False) + self.key = make_linear_att(args.n_embd, args.dim_att, bias=False) + self.value = 
make_linear_att(args.n_embd, args.dim_att, bias=False) + + self.output = nn.Linear(args.dim_att, args.n_embd, bias=False) + self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False) + self.ln_x = nn.GroupNorm(self.n_head, args.dim_att) + + @MyFunction + def jit_func(self, x): + B, T, C = x.size() + + xx = self.time_shift( + x + ) # Mix x with the previous timestep to produce xk, xv, xr + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + xg = x * self.time_mix_g + xx * (1 - self.time_mix_g) + + r = self.receptance(xr) + k = self.key(xk) + v = self.value(xv) + g = F.silu(self.gate(xg)) + + return r, k, v, g + + @MyFunction + def jit_func_2(self, x, g): + B, T, C = x.size() + x = x.view(B * T, C) + x = self.ln_x(x / self.head_size_divisor).view(B, T, C) + x = self.output(x * g) + return x + + def forward(self, x): + B, T, C = x.size() + H = self.n_head + r, k, v, g = self.jit_func(x) + x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa) + + return self.jit_func_2(x, g) + + +######################################################################################################## + + +class RWKV_ChannelMix(MyModule): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 + ddd = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + ddd[0, 0, i] = i / args.n_embd + self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) + self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) + + self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False) + self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False) + self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False) + + @MyFunction + def forward(self, x): + xx = self.time_shift(x) + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + k = self.key(xk) + k = torch.relu(k) ** 2 + kv = self.value(k) + return torch.sigmoid(self.receptance(xr)) * kv + + +class MishGLU(MyModule): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + with torch.no_grad(): + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) + + x = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + x[0, 0, i] = i / args.n_embd + + self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False) + self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False) + self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) + + @MyFunction + def forward(self, x): + xx = self.time_shift(x) + xa = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xb = x * self.time_mix_r + xx * (1 - self.time_mix_r) + a = self.aa(xa) + b = self.bb(xb) + return self.value(a * F.mish(b)) + + +######################################################################################################## +# The RWKV Model with our blocks +######################################################################################################## + + +class Block(nn.Module): + def __init__(self, args, layer_id): + 
super().__init__() + self.args = args + self.layer_id = layer_id + + self.ln1 = nn.LayerNorm(args.n_embd) + self.ln2 = nn.LayerNorm(args.n_embd) + + if self.layer_id == 0: + self.ln0 = nn.LayerNorm(args.n_embd) + if args.my_pos_emb > 0: + self.pos_emb_x = nn.Parameter( + torch.zeros((1, args.my_pos_emb, args.n_embd)) + ) + self.pos_emb_y = nn.Parameter( + torch.zeros((args.my_pos_emb, 1, args.n_embd)) + ) + + if self.layer_id == 0 and self.args.pre_ffn > 0: + self.ffnPre = RWKV_ChannelMix(args, 0) + else: + self.att = RWKV_TimeMix_RWKV5(args, layer_id) + + if "g" in os.environ["RWKV_MY_TESTING"]: + self.ffn = MishGLU(args, layer_id) + else: + self.ffn = RWKV_ChannelMix(args, layer_id) + + if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: + self.tiny_ln = nn.LayerNorm(args.n_embd) + self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) + self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) + self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False) + self.register_buffer( + "tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)) + ) + + if args.dropout > 0: + self.drop0 = nn.Dropout(p=args.dropout) + self.drop1 = nn.Dropout(p=args.dropout) + + def forward(self, x, x_emb=None): + args = self.args + B, T, C = x.size() + if self.layer_id == 0: + x = self.ln0(x) + if args.my_pos_emb > 0: + pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :] + x = x + pos_emb + + if self.args.dropout == 0: + if self.layer_id == 0 and args.pre_ffn > 0: + x = x + self.ffnPre(self.ln1(x)) + else: + x = x + self.att(self.ln1(x)) + x = x + self.ffn(self.ln2(x)) + else: + if self.layer_id == 0 and args.pre_ffn > 0: + x = self.drop0(x + self.ffnPre(self.ln1(x))) + else: + x = self.drop0(x + self.att(self.ln1(x))) + x = self.drop1(x + self.ffn(self.ln2(x))) + + if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: + xx = self.tiny_ln(x) + q = self.tiny_q(xx)[:, :T, :] + k = self.tiny_k(xx)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5)) + c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0) + x = x + c @ self.tiny_v(x_emb) + return x + + +class L2Wrap(torch.autograd.Function): + @staticmethod + def forward(ctx, loss, y): + ctx.save_for_backward(y) + return loss + + @staticmethod + def backward(ctx, grad_output): + y = ctx.saved_tensors[0] + # to encourage the logits to be close to 0 + factor = 1e-4 / (y.shape[0] * y.shape[1]) + maxx, ids = torch.max(y, -1, keepdim=True) + gy = torch.zeros_like(y) + gy.scatter_(-1, ids, maxx * factor) + return (grad_output, gy) + + +class RWKV(pl.LightningModule): + def __init__(self, args): + super().__init__() + self.args = args + if not hasattr(args, "dim_att"): + args.dim_att = args.n_embd + if not hasattr(args, "dim_ffn"): + args.dim_ffn = args.n_embd * 4 + if not hasattr(args, "tiny_att_layer"): + args.tiny_att_layer = -1 + if not hasattr(args, "tiny_att_dim"): + args.tiny_att_dim = -1 + assert args.n_embd % 32 == 0 + assert args.dim_att % 32 == 0 + assert args.dim_ffn % 32 == 0 + + self.emb = nn.Embedding(args.vocab_size, args.n_embd) + + self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)]) + + self.ln_out = nn.LayerNorm(args.n_embd) + self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False) + + if args.head_qk > 0: + self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False) + self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False) + self.register_buffer( + "copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)) + ) + if 
args.dropout > 0: + self.drop0 = nn.Dropout(p=args.dropout) + + def configure_optimizers(self): + args = self.args + + lr_decay = set() + lr_1x = set() + lr_2x = set() + lr_3x = set() + for n, p in self.named_parameters(): + if ("time_mix" in n) and (args.layerwise_lr > 0): + if args.my_pile_stage == 2: + lr_2x.add(n) + else: + lr_1x.add(n) + elif ("time_decay" in n) and (args.layerwise_lr > 0): + if args.my_pile_stage == 2: + lr_3x.add(n) + else: + lr_2x.add(n) + elif ("time_faaaa" in n) and (args.layerwise_lr > 0): + if args.my_pile_stage == 2: + lr_2x.add(n) + else: + lr_1x.add(n) + elif ("time_first" in n) and (args.layerwise_lr > 0): + lr_3x.add(n) + elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0): + lr_decay.add(n) + else: + lr_1x.add(n) + + lr_decay = sorted(list(lr_decay)) + lr_1x = sorted(list(lr_1x)) + lr_2x = sorted(list(lr_2x)) + lr_3x = sorted(list(lr_3x)) + # print('decay', lr_decay) + # print('1x', lr_1x) + # print('2x', lr_2x) + # print('3x', lr_3x) + param_dict = {n: p for n, p in self.named_parameters()} + + if args.layerwise_lr > 0: + if args.my_pile_stage == 2: + optim_groups = [ + { + "params": [param_dict[n] for n in lr_1x], + "weight_decay": 0.0, + "my_lr_scale": 1.0, + }, + { + "params": [param_dict[n] for n in lr_2x], + "weight_decay": 0.0, + "my_lr_scale": 5.0, + }, # test: 2e-3 / args.lr_init}, + { + "params": [param_dict[n] for n in lr_3x], + "weight_decay": 0.0, + "my_lr_scale": 5.0, + }, # test: 3e-3 / args.lr_init}, + ] + else: + optim_groups = [ + { + "params": [param_dict[n] for n in lr_1x], + "weight_decay": 0.0, + "my_lr_scale": 1.0, + }, + { + "params": [param_dict[n] for n in lr_2x], + "weight_decay": 0.0, + "my_lr_scale": 2.0, + }, + { + "params": [param_dict[n] for n in lr_3x], + "weight_decay": 0.0, + "my_lr_scale": 3.0, + }, + ] + else: + optim_groups = [ + { + "params": [param_dict[n] for n in lr_1x], + "weight_decay": 0.0, + "my_lr_scale": 1.0, + } + ] + + if args.weight_decay > 0: + optim_groups += [ + { + "params": [param_dict[n] for n in lr_decay], + "weight_decay": args.weight_decay, + "my_lr_scale": 1.0, + } + ] + if self.deepspeed_offload: + return DeepSpeedCPUAdam( + optim_groups, + lr=self.args.lr_init, + betas=self.args.betas, + eps=self.args.adam_eps, + bias_correction=True, + adamw_mode=True, + amsgrad=False, + ) + return FusedAdam( + optim_groups, + lr=self.args.lr_init, + betas=self.args.betas, + eps=self.args.adam_eps, + bias_correction=True, + adam_w_mode=True, + amsgrad=False, + ) + else: + if self.deepspeed_offload: + return DeepSpeedCPUAdam( + optim_groups, + lr=self.args.lr_init, + betas=self.args.betas, + eps=self.args.adam_eps, + bias_correction=True, + adamw_mode=False, + weight_decay=0, + amsgrad=False, + ) + return FusedAdam( + optim_groups, + lr=self.args.lr_init, + betas=self.args.betas, + eps=self.args.adam_eps, + bias_correction=True, + adam_w_mode=False, + weight_decay=0, + amsgrad=False, + ) + # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) + + @property + def deepspeed_offload(self) -> bool: + strategy = self.trainer.strategy + if isinstance(strategy, DeepSpeedStrategy): + cfg = strategy.config["zero_optimization"] + return cfg.get("offload_optimizer") or cfg.get("offload_param") + return False + + def forward(self, idx): + args = self.args + B, T = idx.size() + assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted." 
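+        # What follows: token embedding -> optional embedding dropout -> the stack
+        # of Blocks (wrapped in gradient checkpointing when --grad_cp 1;
+        # torch_checkpoint for LoRA runs, deepspeed checkpointing otherwise) ->
+        # final LayerNorm -> LM head, plus the optional headQK correction term.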
+ + x = self.emb(idx) + x_emb = x + + if args.dropout > 0: + x = self.drop0(x) + if args.tiny_att_dim > 0: + for block in self.blocks: + if args.grad_cp == 1: + if args.lora: + x = torch_checkpoint(block, x, x_emb, use_reentrant=False) + else: + x = deepspeed.checkpointing.checkpoint(block, x, x_emb) + else: + x = block(x, x_emb) + else: + for block in self.blocks: + if args.grad_cp == 1: + if args.lora: + x = torch_checkpoint(block, x, x_emb, use_reentrant=False) + else: + x = deepspeed.checkpointing.checkpoint(block, x) + else: + x = block(x) + + x = self.ln_out(x) + + if args.head_qk > 0: + q = self.head_q(x)[:, :T, :] + k = self.head_k(x)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk) + c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) + + if "32" in os.environ["RWKV_FLOAT_MODE"]: + c = c @ F.one_hot(idx, num_classes=args.vocab_size) + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": + c = c @ F.one_hot(idx, num_classes=args.vocab_size).half() + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16() + + x = self.head(x) + c + else: + x = self.head(x) + + return x + + def training_step(self, batch, batch_idx): + args = self.args + if args.my_qa_mask != 1: + idx, targets = batch + logits = self(idx) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + # if '0' in os.environ["RWKV_MY_TESTING"]: + # print('logits', logits) + # torch.set_printoptions(threshold=10000) + # print('idx', idx) + # exit(0) + else: + idx, targets, mask = batch + mask = mask.view(-1) + sum_mask = torch.sum(mask).item() + # if sum_mask == 0: + # return torch.tensor([0.0], requires_grad=True) + + logits = self(idx) + if sum_mask == mask.shape[0]: + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1) + ) + # print('rank', self.global_rank, 'loss', loss.item()) + else: + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none" + ) + # loss_raw = loss + loss = torch.sum(loss * mask) / sum_mask + + # torch.set_printoptions(threshold=10000) + # if True: #self.global_rank == 1: + # tmp = '' + # sss = 0 + # ccc = 0 + # for i in range(mask.shape[0]): + # if mask[i] > 0: + # tmp += str(idx.view(-1)[i].item()) + ',' + # sss += loss_raw.view(-1)[i].float().item() + # ccc += 1 + # print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx) + return L2Wrap.apply(loss, logits) + + def training_step_end(self, batch_parts): + if pl.__version__[0] != "2": + all = self.all_gather(batch_parts) + if self.trainer.is_global_zero: + self.trainer.my_loss_all = all + + def generate_init_weight(self): + print( + f""" +############################################################################ +# +# Init model weight (slow for large models)... +# +############################################################################ +""" + ) + m = {} + for n in self.state_dict(): + p = self.state_dict()[n] + shape = p.shape + + gain = 1.0 + scale = 1.0 + if ( + "ln_" in n + or ".ln" in n + or "time_" in n + or "_mask" in n + or "pos_emb" in n + or ".mask." 
in n + ): + if "ln_x.weight" in n: + layer_scale = (1 + int(n.split(".")[1])) / self.args.n_layer + m[n] = (p * 0.0) + (layer_scale**0.7) + else: + m[n] = p + else: + if n == "emb.weight": + scale = -1 * self.args.lr_init + else: + if shape[0] > shape[1]: + gain = math.sqrt(shape[0] / shape[1]) + + zero = [ + ".att.output.", + ".ffn.value.", + ".ffn.receptance.", + ".ffnPre.value.", + ".ffnPre.receptance.", + "head_q.", + ".oo.", + ".rr.", + ] + + for kk in zero: + if kk in n: + scale = 0 + if n == "head.weight": + scale = 0.5 + if "head_k." in n: + scale = 0.1 + if "head_q." in n: + scale = 0 + + print( + f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}" + ) + + if self.args.accelerator.upper() == "GPU": + m[n] = torch.empty((shape[0], shape[1]), device="cuda") + else: + m[n] = torch.empty((shape[0], shape[1])) + + if scale == 0: + nn.init.zeros_(m[n]) + elif scale < 0: + nn.init.uniform_(m[n], a=scale, b=-scale) + else: + nn.init.orthogonal_(m[n], gain=gain * scale) + + m[n] = m[n].cpu() + if os.environ["RWKV_FLOAT_MODE"] == "fp16": + m[n] = m[n].half() + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + m[n] = m[n].bfloat16() + + # if n == "emb.weight": + # print(m[n]) + + gc.collect() + torch.cuda.empty_cache() + return m diff --git a/finetune/lora/v5/src/trainer.py b/finetune/lora/v5/src/trainer.py new file mode 100644 index 0000000..e14e7fc --- /dev/null +++ b/finetune/lora/v5/src/trainer.py @@ -0,0 +1,310 @@ +import os, math, time, datetime, subprocess +import torch +from torch.utils.data import DataLoader +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_info, rank_zero_only +from .model import LORA_CONFIG + + +def my_save(args, trainer, dd, ff): + if "14b-run1" in ff: + fn = ff.split("/")[-1] + fff = "/dev/shm/" + fn + torch.save(dd, fff) + subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True) + elif ("world/14b" in ff) or ("world/7b" in ff): + aa = ff.split("/")[1] + fn = ff.split("/")[-1] + fff = f"/dev/shm/{aa}-{fn}" + torch.save(dd, fff) + subprocess.Popen( + f" aws s3 mv {fff} s3://rwkv-world/{aa}-{fn} --quiet", shell=True + ) + else: + if "deepspeed_stage_3" in args.strategy: + trainer.save_checkpoint(ff, weights_only=True) + else: + torch.save(dd, ff) + + +class train_callback(pl.Callback): + def __init__(self, args): + super().__init__() + self.args = args + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + args = self.args + # if args.cuda_cleanup > 0: + # torch.cuda.empty_cache() + real_step = trainer.global_step + args.epoch_begin * args.epoch_steps + + # LR schedule + w_step = args.warmup_steps + if args.lr_final == args.lr_init or args.epoch_count == 0: + lr = args.lr_init + else: + decay_step = real_step - args.my_pile_edecay * args.epoch_steps + decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps + progress = (decay_step - w_step + 1) / (decay_total - w_step) + progress = min(1, max(0, progress)) + + if args.lr_final == 0 or args.lr_init == 0: # linear decay + lr = args.lr_init + (args.lr_final - args.lr_init) * progress + else: # exp decay + lr = args.lr_init * math.exp( + math.log(args.lr_final / args.lr_init) * pow(progress, 1) + ) + # if trainer.is_global_zero: + # print(trainer.global_step, decay_step, decay_total, w_step, progress, lr) + + if args.my_exit_tokens != 0: # cosine decay + real_tokens = real_step * args.ctx_len * args.real_bsz + warmup_tokens = w_step * args.ctx_len * args.real_bsz + progress = (real_tokens - 
warmup_tokens) / ( + abs(args.my_exit_tokens) - warmup_tokens + ) + progress = max(0, min(1, progress)) + lr_final_factor = args.lr_final / args.lr_init + lr_mult = (0.5 + lr_final_factor / 2) + ( + 0.5 - lr_final_factor / 2 + ) * math.cos(math.pi * progress) + if args.my_exit_tokens > 0: + lr = args.lr_init * lr_mult + else: + lr = (lr + args.lr_init * lr_mult) / 2 + if progress >= 1: + if (trainer.is_global_zero) or ("deepspeed_stage_3" in args.strategy): + my_save( + args, + trainer, + pl_module.state_dict(), + f"{args.proj_dir}/rwkv-final.pth", + ) + exit(0) + if trainer.global_step < w_step: + lr = lr * (0.2 + 0.8 * trainer.global_step / w_step) + + if args.weight_decay_final > 0: + wd_now = args.weight_decay * math.exp( + math.log(args.weight_decay_final / args.weight_decay) * progress + ) + else: + wd_now = args.weight_decay + + for param_group in trainer.optimizers[0].param_groups: + if param_group["weight_decay"] > 0: + param_group["weight_decay"] = wd_now + if args.layerwise_lr > 0: + param_group["lr"] = lr * param_group["my_lr_scale"] + # print(param_group["lr"], param_group["my_lr_scale"]) + else: + param_group["lr"] = lr + + trainer.my_lr = lr + trainer.my_wd = wd_now + # rank_zero_info(f"{real_step} {lr}") + + if trainer.global_step == 0: + if trainer.is_global_zero: # logging + trainer.my_loss_sum = 0 + trainer.my_loss_count = 0 + trainer.my_log = open(args.proj_dir + "/train_log.txt", "a") + trainer.my_log.write( + f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n" + ) + try: + print(f"\n{trainer.strategy.config}\n") + trainer.my_log.write(f"{trainer.strategy.config}\n") + except: + pass + trainer.my_log.flush() + if len(args.wandb) > 0: + print("Login to wandb...") + import wandb + + wandb.init( + project=args.wandb, + name=args.run_name + " " + args.my_timestamp, + config=args, + save_code=False, + ) + trainer.my_wandb = wandb + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + args = self.args + token_per_step = args.ctx_len * args.real_bsz + real_step = trainer.global_step + args.epoch_begin * args.epoch_steps + if trainer.is_global_zero: # logging + t_now = time.time_ns() + kt_s = 0 + try: + t_cost = (t_now - trainer.my_time_ns) / 1e9 + kt_s = token_per_step / t_cost / 1000 + self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True) + self.log("Kt/s", kt_s, prog_bar=True, on_step=True) + except: + pass + trainer.my_time_ns = t_now + if pl.__version__[0] == "2": + trainer.my_loss = outputs["loss"] + else: + trainer.my_loss = trainer.my_loss_all.float().mean().item() + trainer.my_loss_sum += trainer.my_loss + trainer.my_loss_count += 1 + trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count + self.log("lr", trainer.my_lr, prog_bar=True, on_step=True) + self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True) + # self.log("s", real_step, prog_bar=True, on_step=True) + + if len(args.wandb) > 0: + lll = { + "loss": trainer.my_loss, + "lr": trainer.my_lr, + "wd": trainer.my_wd, + "Gtokens": real_step * token_per_step / 1e9, + } + if kt_s > 0: + lll["kt/s"] = kt_s + trainer.my_wandb.log(lll, step=int(real_step)) + if (trainer.is_global_zero) or ( + "deepspeed_stage_3" in args.strategy + ): # save pth + if args.magic_prime > 0: + expand_factor = 2 if args.my_qa_mask > 0 else 1 + if int(real_step) == int( + args.magic_prime * expand_factor // args.real_bsz + ) - 1 + int(args.my_random_steps): + to_save_dict = pl_module.state_dict() + my_save( + args, + trainer, + to_save_dict, + f"{args.proj_dir}/rwkv-final.pth", 
+ ) + # if args.batch_save==batch_idx : + # to_save_dict = pl_module.state_dict() + # for name, state in to_save_dict.items(): + # if 'img' in name: + # to_save_dict[name] = state + # try: + # my_save( + # args, trainer, + # to_save_dict, + # f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}-{batch_idx}.pth", + # ) + # except Exception as e: + # print('Error\n\n', e, '\n\n') + + def on_train_epoch_start(self, trainer, pl_module): + args = self.args + if pl.__version__[0] == "2": + dataset = trainer.train_dataloader.dataset + else: + dataset = trainer.train_dataloader.dataset.datasets + assert "MyDataset" in str(dataset) + dataset.global_rank = trainer.global_rank + dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch) + dataset.world_size = trainer.world_size + # print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########') + + def on_train_epoch_end(self, trainer, pl_module): + args = self.args + to_save_dict = {} + if (trainer.is_global_zero) or ( + "deepspeed_stage_3" in args.strategy + ): # save pth + if ( + args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0 + ) or (trainer.current_epoch == args.epoch_count - 1): + if args.data_type == "wds_img": + raw_dict = pl_module.state_dict() + for k in raw_dict: + if k.startswith("encoder.") or k.startswith("decoder."): + to_save_dict[k] = raw_dict[k] + else: + to_save_dict = pl_module.state_dict() + + if args.data_type == "img" and not args.lora: + for name, state in to_save_dict.items(): + if "img" in name: + to_save_dict[name] = state + + if args.lora: + enable_time_finetune = "time" in LORA_CONFIG["parts"] + enable_ln_finetune = "ln" in LORA_CONFIG["parts"] + lora_dict = {} + for name, state in to_save_dict.items(): + if "img" in name: + lora_dict[name] = state + if ( + ".lora_" in name + or (enable_time_finetune and ".time_" in name) + or (enable_ln_finetune and ".ln" in name) + ): + lora_dict[name] = state + to_save_dict = lora_dict + + try: + my_save( + args, + trainer, + to_save_dict, + f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth", + ) + except Exception as e: + print("Error\n\n", e, "\n\n") + + if trainer.is_global_zero: # logging + trainer.my_log.write( + f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n" + ) + trainer.my_log.flush() + + trainer.my_loss_sum = 0 + trainer.my_loss_count = 0 + if (args.epoch_begin + trainer.current_epoch) >= args.my_exit: + exit(0) + + +@rank_zero_only +def generate_init_weight(model, init_weight_name): + mm = model.generate_init_weight() + + if model.args.my_pile_stage == 1: + if len(model.args.load_model) > 0: + print(f"Combine weights from {model.args.load_model}...") + load_dict = torch.load(model.args.load_model, map_location="cpu") + for k in load_dict: + try: + assert k in mm + except: + print("missing", k) + exit(0) + src = load_dict[k] + try: + mm[k] = src.reshape(mm[k].shape) + except: + tmp = mm[k].squeeze().clone() + print(k, src.shape, "-->", mm[k].shape) + ss = src.shape[0] + dd = tmp.shape[0] + for i in range(dd): + pos = i / dd * ss + if pos >= ss - 1: + tmp[i] = src[ss - 1] + else: + p0 = int(math.floor(pos)) + ii = pos - p0 + tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii) + mm[k] = tmp.reshape(mm[k].shape) + sss = src.squeeze().float().cpu().numpy() + print(sss[:10], "...", sss[-10:]) + mmm = 
mm[k].squeeze().float().cpu().numpy() + print(mmm[:10], "...", mmm[-10:]) + + print(f"Save to {init_weight_name}...") + torch.save(mm, init_weight_name) + + if model.args.my_pile_stage == 1: + print("Done. Now go for stage 2.") + exit(0) diff --git a/finetune/lora/v5/src/utils.py b/finetune/lora/v5/src/utils.py new file mode 100644 index 0000000..87da098 --- /dev/null +++ b/finetune/lora/v5/src/utils.py @@ -0,0 +1,139 @@ +import json, time, random, os +import numpy as np +import torch +from torch.nn import functional as F + +time_slot = {} +time_ref = time.time_ns() + + +def record_time(name): + if name not in time_slot: + time_slot[name] = 1e20 + tt = (time.time_ns() - time_ref) / 1e9 + if tt < time_slot[name]: + time_slot[name] = tt + + +class TOKENIZER: + def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"): + if "list" in str(type(WORD_NAME)): + self.charMode = False + if WORD_NAME[0] == WORD_NAME[1]: + from transformers import PreTrainedTokenizerFast + + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0]) + else: + from transformers import GPT2TokenizerFast + + self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) + self.vocab_size = len(self.tokenizer) + else: + self.charMode = True + with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file: + self.word_table = json.load(result_file) + + self.vocab_size = len(self.word_table) + + self.stoi = {v: int(k) for k, v in self.word_table.items()} + self.itos = {int(k): v for k, v in self.word_table.items()} + + self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] + + def refine_context(self, context): + context = context.strip().split("\n") + for c in range(len(context)): + context[c] = context[c].strip().strip("\u3000").strip("\r") + context = list(filter(lambda c: c != "", context)) + context = "\n" + ("\n".join(context)).strip() + if context == "": + context = "\n" + return context + + def sample_logits( + self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None + ): + # out[self.UNKNOWN_CHAR] = -float('Inf') + lastChar = int(x[-1]) + + probs = F.softmax(out, dim=-1) + + if self.charMode: + if self.itos[lastChar] == "\n": + top_p = top_p_newline + else: + top_p = top_p_usual + else: + top_p = top_p_usual + + if os.environ["RWKV_RUN_DEVICE"] == "cpu": + probs = probs.numpy() + sorted_probs = np.sort(probs)[::-1] + cumulative_probs = np.cumsum(sorted_probs) + cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) + probs[probs < cutoff] = 0 + if temperature != 1.0: + probs = probs.pow(1.0 / temperature) + probs = probs / np.sum(probs) + out = np.random.choice(a=len(probs), p=probs) + return out + else: + sorted_probs = torch.sort(probs, descending=True)[0] + cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy() + cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) + probs[probs < cutoff] = 0 + if temperature != 1.0: + probs = probs.pow(1.0 / temperature) + out = torch.multinomial(probs, num_samples=1)[0] + return out + + +def MaybeIsPrime(number): + if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number): + return True + else: + return False + + +def FermatPrimalityTest(number): + if number > 1: + for time in range(3): + randomNumber = random.randint(2, number) - 1 + if pow(randomNumber, number - 1, number) != 1: + return False + return True + else: + return False + + +def MillerRabinPrimalityTest(number): + if number == 2: + return True + elif number == 1 or number % 2 == 0: + return False + oddPartOfNumber = number - 1 + 
timesTwoDividNumber = 0 + while oddPartOfNumber % 2 == 0: + oddPartOfNumber = oddPartOfNumber // 2 + timesTwoDividNumber = timesTwoDividNumber + 1 + + for time in range(3): + while True: + randomNumber = random.randint(2, number) - 1 + if randomNumber != 0 and randomNumber != 1: + break + + randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number) + + if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1): + iterationNumber = 1 + + while (iterationNumber <= timesTwoDividNumber - 1) and ( + randomNumberWithPower != number - 1 + ): + randomNumberWithPower = pow(randomNumberWithPower, 2, number) + iterationNumber = iterationNumber + 1 + if randomNumberWithPower != (number - 1): + return False + + return True diff --git a/finetune/lora/v5/train.py b/finetune/lora/v5/train.py new file mode 100644 index 0000000..b41cdbd --- /dev/null +++ b/finetune/lora/v5/train.py @@ -0,0 +1,436 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import logging + +logging.basicConfig(level=logging.INFO) + +if __name__ == "__main__": + from argparse import ArgumentParser + from pytorch_lightning import Trainer + from pytorch_lightning.utilities import rank_zero_info, rank_zero_only + import pytorch_lightning as pl + + rank_zero_info("########## work in progress ##########") + + parser = ArgumentParser() + + parser.add_argument("--load_model", default="", type=str) # full path, with .pth + parser.add_argument( + "--wandb", default="", type=str + ) # wandb project name. if "" then don't use wandb + parser.add_argument("--proj_dir", default="out", type=str) + parser.add_argument("--random_seed", default="-1", type=int) + + parser.add_argument("--data_file", default="", type=str) + parser.add_argument("--data_type", default="utf-8", type=str) + parser.add_argument( + "--vocab_size", default=0, type=int + ) # vocab_size = 0 means auto (for char-level LM and .txt data) + + parser.add_argument("--ctx_len", default=1024, type=int) + parser.add_argument( + "--epoch_steps", default=1000, type=int + ) # a mini "epoch" has [epoch_steps] steps + parser.add_argument( + "--epoch_count", default=500, type=int + ) # train for this many "epochs". 
will continue afterwards with lr = lr_final + parser.add_argument( + "--epoch_begin", default=0, type=int + ) # if you load a model trained for x "epochs", set epoch_begin = x + parser.add_argument( + "--epoch_save", default=5, type=int + ) # save the model every [epoch_save] "epochs" + + parser.add_argument( + "--micro_bsz", default=12, type=int + ) # micro batch size (batch size per GPU) + parser.add_argument("--n_layer", default=6, type=int) + parser.add_argument("--n_embd", default=512, type=int) + parser.add_argument("--dim_att", default=0, type=int) + parser.add_argument("--dim_ffn", default=0, type=int) + parser.add_argument( + "--pre_ffn", default=0, type=int + ) # replace first att layer by ffn (sometimes better) + parser.add_argument("--head_qk", default=0, type=int) # my headQK trick + parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim + parser.add_argument( + "--tiny_att_layer", default=-999, type=int + ) # tiny attention @ which layer + + parser.add_argument( + "--lr_init", default=6e-4, type=float + ) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 + parser.add_argument("--lr_final", default=1e-5, type=float) + parser.add_argument( + "--warmup_steps", default=-1, type=int + ) # try 50 if you load a model + parser.add_argument("--beta1", default=0.9, type=float) + parser.add_argument( + "--beta2", default=0.99, type=float + ) # use 0.999 when your model is close to convergence + parser.add_argument("--adam_eps", default=1e-8, type=float) + parser.add_argument( + "--grad_cp", default=0, type=int + ) # gradient checkpt: saves VRAM, but slower + parser.add_argument( + "--dropout", default=0, type=float + ) # try 0.01 / 0.02 / 0.05 / 0.1 + parser.add_argument( + "--weight_decay", default=0, type=float + ) # try 0.1 / 0.01 / 0.001 + parser.add_argument("--weight_decay_final", default=-1, type=float) + + parser.add_argument( + "--my_pile_version", default=1, type=int + ) # my special pile version + parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode + parser.add_argument( + "--my_pile_shift", default=-1, type=int + ) # my special pile mode - text shift + parser.add_argument("--my_pile_edecay", default=0, type=int) + parser.add_argument( + "--layerwise_lr", default=1, type=int + ) # layerwise lr for faster convergence (but slower it/s) + parser.add_argument( + "--ds_bucket_mb", default=200, type=int + ) # deepspeed bucket size in MB. 
200 seems enough + # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful) + + parser.add_argument("--my_sample_len", default=0, type=int) + parser.add_argument("--my_ffn_shift", default=1, type=int) + parser.add_argument("--my_att_shift", default=1, type=int) + parser.add_argument( + "--head_size_a", default=64, type=int + ) # can try larger values for larger models + parser.add_argument("--head_size_divisor", default=8, type=int) + parser.add_argument("--my_pos_emb", default=0, type=int) + parser.add_argument("--load_partial", default=0, type=int) + parser.add_argument("--magic_prime", default=0, type=int) + parser.add_argument("--my_qa_mask", default=0, type=int) + parser.add_argument("--my_random_steps", default=0, type=int) + parser.add_argument("--my_testing", default="", type=str) + parser.add_argument("--my_exit", default=99999999, type=int) + parser.add_argument("--my_exit_tokens", default=0, type=int) + + # LORA + parser.add_argument("--emb", action="store_true") + parser.add_argument("--lora", action="store_true") + parser.add_argument("--lora_load", default="", type=str) + parser.add_argument("--lora_r", default=8, type=int) + parser.add_argument("--lora_alpha", default=32, type=float) + parser.add_argument("--lora_dropout", default=0.01, type=float) + parser.add_argument("--lora_parts", default="att,ln,time", type=str) + + if pl.__version__[0] == "2": + parser.add_argument("--accelerator", default="gpu", type=str) + parser.add_argument("--strategy", default="auto", type=str) + parser.add_argument("--devices", default=1, type=int) + parser.add_argument("--num_nodes", default=1, type=int) + parser.add_argument("--precision", default="fp16", type=str) + parser.add_argument("--accumulate_grad_batches", default=1, type=int) + else: + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args() + + ######################################################################################################## + + import os, warnings, math, datetime, sys, time + import numpy as np + import torch + from torch.utils.data import DataLoader + + if "deepspeed" in args.strategy: + import deepspeed + from pytorch_lightning import seed_everything + + if args.random_seed >= 0: + print( + f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" + * 3 + ) + seed_everything(args.random_seed) + + np.set_printoptions(precision=4, suppress=True, linewidth=200) + warnings.filterwarnings( + "ignore", ".*Consider increasing the value of the `num_workers` argument*" + ) + warnings.filterwarnings( + "ignore", ".*The progress bar already tracks a metric with the*" + ) + # os.environ["WDS_SHOW_SEED"] = "1" + + args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") + args.enable_checkpointing = False + args.replace_sampler_ddp = False + args.logger = False + args.gradient_clip_val = 1.0 + args.num_sanity_val_steps = 0 + args.check_val_every_n_epoch = int(1e20) + args.log_every_n_steps = int(1e20) + args.max_epochs = args.epoch_count # -1 continue forever + args.betas = (args.beta1, args.beta2) + args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz + os.environ["RWKV_MY_TESTING"] = args.my_testing + os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a) + if args.dim_att <= 0: + args.dim_att = args.n_embd + if args.dim_ffn <= 0: + args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size + + if args.data_type == "wds_img": + args.run_name = 
f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}" + args.proj_dir = f"{args.proj_dir}-{args.run_name}" + else: + args.run_name = ( + f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" + ) + if not os.path.exists(args.proj_dir): + os.makedirs(args.proj_dir) + + if args.my_pile_stage > 0: + magic_prime_bak = args.magic_prime + + if args.my_pile_shift < 0: + args.my_pile_shift = 0 + + if magic_prime_bak > 0: + args.magic_prime = magic_prime_bak + if args.my_qa_mask == 2: + args.epoch_count = 2 * args.magic_prime // 40320 + else: + args.epoch_count = args.magic_prime // 40320 + + args.epoch_steps = 40320 // args.real_bsz + assert args.epoch_steps * args.real_bsz == 40320 + # if args.my_pile_stage == 2: + # assert args.lr_final == args.lr_init + if args.my_pile_stage >= 2: # find latest saved model + list_p = [] + for p in os.listdir(args.proj_dir): + if p.startswith("rwkv") and p.endswith(".pth"): + p = ((p.split("-"))[1].split("."))[0] + if p != "final": + if p == "init": + p = -1 + else: + p = int(p) + list_p += [p] + list_p.sort() + max_p = list_p[-1] + if len(list_p) > 1: + args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted + if max_p == -1: + args.load_model = f"{args.proj_dir}/rwkv-init.pth" + else: + args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" + if args.warmup_steps < 0: + if args.my_pile_stage == 2: + args.warmup_steps = 10 + else: + args.warmup_steps = 30 + args.epoch_begin = max_p + 1 + + samples_per_epoch = args.epoch_steps * args.real_bsz + tokens_per_epoch = samples_per_epoch * args.ctx_len + try: + deepspeed_version = deepspeed.__version__ + except: + deepspeed_version = None + pass + rank_zero_info( + f""" +############################################################################ +# +# RWKV-5 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''} +# +# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir} +# +# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1}, save every {args.epoch_save} epoch +# +# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens +# +# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len +# +# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} +# +# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer +# Found deepspeed {deepspeed_version}, recommend 0.7.0 (faster than newer versions) +# Found pytorch_lightning {pl.__version__}, recommend 1.9.5 +# +############################################################################ +""" + ) + rank_zero_info(str(vars(args)) + "\n") + + assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "uint16"] + + if args.lr_final == 0 or args.lr_init == 0: + rank_zero_info( + "\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n" + ) + + assert args.precision in ["fp32", "tf32", "fp16", "bf16"] + os.environ["RWKV_FLOAT_MODE"] = args.precision + if args.precision == "fp32": + for i in range(10): + rank_zero_info( + "\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n" + ) + if args.precision == "fp16": + rank_zero_info( + "\n\nNote: you are using fp16 (might overflow). 
Try bf16 / tf32 for stable training.\n\n" + ) + + os.environ["RWKV_JIT_ON"] = "0" + if "deepspeed_stage_3" in args.strategy: + os.environ["RWKV_JIT_ON"] = "0" + + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.enabled = True + if args.precision == "fp32": + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + else: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + if "32" in args.precision: + args.precision = 32 + elif args.precision == "fp16": + args.precision = 16 + else: + args.precision = "bf16" + + ######################################################################################################## + + from src.trainer import train_callback, generate_init_weight + from src.dataset import MyDataset + + train_data = MyDataset(args) + args.vocab_size = train_data.vocab_size + + from src.model import RWKV, LORA_CONFIG, LoraLinear + + if args.lora: + assert args.lora_r > 0, "LoRA should have its `r` > 0" + LORA_CONFIG["r"] = args.lora_r + LORA_CONFIG["alpha"] = args.lora_alpha + LORA_CONFIG["dropout"] = args.lora_dropout + LORA_CONFIG["parts"] = set(str(args.lora_parts).split(",")) + enable_time_finetune = "time" in LORA_CONFIG["parts"] + enable_ln_finetune = "ln" in LORA_CONFIG["parts"] + model = RWKV(args) + # only train lora parameters + if args.lora: + model.requires_grad_(False) + for name, module in model.named_modules(): + if any(n.startswith("lora_") for n, _ in module.named_parameters()): + print(f" LoRA additionally training module {name}") + for pname, param in module.named_parameters(): + param.requires_grad = "lora_" in pname + elif enable_ln_finetune and ".ln" in name: + print(f" LoRA additionally training module {name}") + for param in module.parameters(): + param.requires_grad = True + elif enable_time_finetune and any( + n.startswith("time") for n, _ in module.named_parameters() + ): + for pname, param in module.named_parameters(): + if pname.startswith("time"): + print(f" LoRA additionally training parameter {pname}") + param.requires_grad = True + + if ( + len(args.load_model) == 0 or args.my_pile_stage == 1 + ): # shall we build the initial weights? + init_weight_name = f"{args.proj_dir}/rwkv-init.pth" + generate_init_weight(model, init_weight_name) # save initial weights + args.load_model = init_weight_name + + rank_zero_info(f"########## Loading {args.load_model}... 
##########") + try: + load_dict = torch.load(args.load_model, map_location="cpu") + load_keys = list(load_dict.keys()) + for k in load_keys: + if k.startswith("_forward_module."): + load_dict[k.replace("_forward_module.", "")] = load_dict[k] + del load_dict[k] + except: + rank_zero_info(f"Bad checkpoint {args.load_model}") + if args.my_pile_stage >= 2: # try again using another checkpoint + max_p = args.my_pile_prev_p + if max_p == -1: + args.load_model = f"{args.proj_dir}/rwkv-init.pth" + else: + args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" + args.epoch_begin = max_p + 1 + rank_zero_info(f"Trying {args.load_model}") + load_dict = torch.load(args.load_model, map_location="cpu") + + if args.load_partial == 1: + load_keys = load_dict.keys() + for k in model.state_dict(): + if k not in load_keys: + load_dict[k] = model.state_dict()[k] + # model.load_state_dict(load_dict) + + model.load_state_dict(load_dict, strict=(not args.lora)) + if os.path.isfile(args.lora_load): + model.load_state_dict( + torch.load(args.lora_load, map_location="cpu"), strict=False + ) + + if pl.__version__[0] == "2": + trainer = Trainer( + accelerator=args.accelerator, + strategy=args.strategy, + devices=args.devices, + num_nodes=args.num_nodes, + precision=args.precision, + logger=args.logger, + callbacks=[train_callback(args)], + max_epochs=args.max_epochs, + check_val_every_n_epoch=args.check_val_every_n_epoch, + num_sanity_val_steps=args.num_sanity_val_steps, + log_every_n_steps=args.log_every_n_steps, + enable_checkpointing=args.enable_checkpointing, + accumulate_grad_batches=args.accumulate_grad_batches, + gradient_clip_val=args.gradient_clip_val, + ) + else: + trainer = Trainer.from_argparse_args( + args, + callbacks=[train_callback(args)], + ) + + if trainer.global_rank == 0: + for n in model.state_dict(): + shape = model.state_dict()[n].shape + shape = [i for i in shape if i != 1] + if len(shape) > 1: + print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") + else: + print(f"{str(shape[0]).ljust(5)} {n}") + + if "deepspeed" in args.strategy: + trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = ( + args.ds_bucket_mb * 1000 * 1000 + ) + trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = ( + args.ds_bucket_mb * 1000 * 1000 + ) + + # must set shuffle=False, persistent_workers=False (because worker is in another thread) + data_loader = DataLoader( + train_data, + shuffle=False, + pin_memory=True, + batch_size=args.micro_bsz, + num_workers=1, + persistent_workers=False, + drop_last=True, + ) + + trainer.fit(model, data_loader) diff --git a/frontend/src/pages/Train.tsx b/frontend/src/pages/Train.tsx index 23e7ccb..cc4aff2 100644 --- a/frontend/src/pages/Train.tsx +++ b/frontend/src/pages/Train.tsx @@ -131,7 +131,7 @@ const showError = (e: any) => { }; const errorsMap = Object.entries({ - 'python3 ./finetune/lora/train.py': 'Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.', + 'python3 ./finetune/lora/v': 'Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.', 'cuda out of memory': 'VRAM is not enough', 'valueerror: high <= 0': 'Training data is not enough, reduce context length or add more data for training', '+= \'+ptx\'': 'Can not find an Nvidia GPU. Perhaps the gpu driver of windows is too old, or you are using WSL 1 for training, please upgrade to WSL 2. e.g. 
Run "wsl --set-version Ubuntu-22.04 2"', @@ -299,7 +299,6 @@ const LoraFinetune: FC = observer(() => { (loraParams.baseModel ? `--load_model models/${loraParams.baseModel} ` : '') + (loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') + `--data_file ${convertedDataPath} ` + - `--vocab_size ${loraParams.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` + `--ctx_len ${ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` + `--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` + `--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` +