rwkv5 lora finetune support (https://github.com/JL-er/RWKV-v5-lora)

parent b7f4dd835e
commit 81544ca8b3
@@ -32,6 +32,7 @@ cleaner_thread.start()
 w = torch.load(model_file, map_location="cpu")
 gc.collect()
 
+vocab_size = w["emb.weight"].shape[0]
 n_embd = w["emb.weight"].shape[1]
 n_layer = 0
 keys = list(w.keys())
@@ -52,6 +53,9 @@ for x in keys:
         version = max(6, version)
 
 if version <= expected_max_version:
-    print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
+    print(
+        f"v{int(version)}/train.py --vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}",
+        end="",
+    )
 else:
     raise Exception(f"RWKV{version} is not supported")
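
A quick illustration of the new output format of the detection script (standalone sketch, not part of this diff; the numeric values are hypothetical): the printed string now starts with a version-specific trainer path, so a caller can use it both to pick the matching train.py and to pass the detected model dimensions.

version, vocab_size, n_layer, n_embd = 5, 65536, 24, 2048  # hypothetical values
print(
    f"v{int(version)}/train.py --vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}",
    end="",
)
# prints: v5/train.py --vocab_size 65536 --n_layer 24 --n_embd 2048

The launcher in the next hunk relies on exactly this shape when it expands $modelInfo into "python3 ./finetune/lora/$modelInfo ...".
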
@@ -47,10 +47,10 @@ else
 fi
 
 echo "loading $loadModel"
-modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 4)
+modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 5.2)
 echo $modelInfo
 if [[ $modelInfo =~ "--n_layer" ]]; then
-  python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
+  python3 ./finetune/lora/$modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
     --lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu
 else
   echo "modelInfo is invalid"
@@ -7,6 +7,7 @@ import struct
 from functools import lru_cache
 from itertools import accumulate
 
+
 def print_rank_0(*message):
     pass
     # """If distributed is initialized print only on rank 0."""
@@ -16,12 +17,14 @@ def print_rank_0(*message):
     # else:
     #     print(*message, flush=True)
 
+
 def _warmup_mmap_file(path):
     pass
     # with open(path, "rb") as stream:
     #     while stream.read(100 * 1024 * 1024):
     #         pass
 
+
 dtypes = {
     1: np.uint8,
     2: np.int8,
@@ -33,18 +36,22 @@ dtypes = {
     8: np.uint16,
 }
 
+
 def code(dtype):
     for k in dtypes.keys():
         if dtypes[k] == dtype:
             return k
     raise ValueError(dtype)
 
+
 def index_file_path(prefix_path):
     return prefix_path + ".idx"
 
+
 def data_file_path(prefix_path):
     return prefix_path + ".bin"
 
+
 class MMapIndexedDataset(torch.utils.data.Dataset):
     class Index(object):
         _HDR_MAGIC = b"MMIDIDX\x00\x00"
@@ -217,8 +224,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         elif isinstance(idx, slice):
             start, stop, step = idx.indices(len(self))
             if step != 1:
-                raise ValueError(
-                    "Slices into indexed_dataset must be contiguous")
+                raise ValueError("Slices into indexed_dataset must be contiguous")
             ptr = self._index._pointers[start]
             sizes = self._index._sizes[idx]
             offsets = list(accumulate(sizes))
@@ -17,9 +17,11 @@ class MyDataset(Dataset):
 
         if args.data_type == "binidx":
             self.vocab_size = args.vocab_size
-            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
+            rank_zero_info(
+                f"Current vocab size = {self.vocab_size} (make sure it's correct)"
+            )
 
-            if args.data_file.endswith('/'):
+            if args.data_file.endswith("/"):
                 d_all = []
                 for p in os.listdir(args.data_file):
                     if p.endswith(".idx"):
@@ -29,33 +31,52 @@ class MyDataset(Dataset):
                 exit(0)
             else:
                 self.data = MMapIndexedDataset(args.data_file)
-                self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
+                self.data_size = (
+                    len(self.data._bin_buffer) // self.data._index._dtype_size
+                )
                 rank_zero_info(f"Data has {self.data_size} tokens.")
 
             if args.my_qa_mask > 0:
-                self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
-                self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size
+                self.data_pile = MMapIndexedDataset(
+                    "/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document"
+                )
+                self.data_pile_size = (
+                    len(self.data_pile._bin_buffer) // self.data._index._dtype_size
+                )
 
             if args.my_pile_stage > 0:
                 # assert self.data_size == 332115325534 and self.vocab_size == 50277
                 self.samples_per_epoch = args.epoch_steps * args.real_bsz
                 assert self.samples_per_epoch == 40320
-                rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
+                rank_zero_info(
+                    f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########"
+                )
                 dataset_slot = self.data_size // args.ctx_len
                 if args.my_pile_stage != 4:
                     assert MaybeIsPrime(args.magic_prime)
                     assert args.magic_prime % 3 == 2
-                    assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
+                    assert (
+                        args.magic_prime / dataset_slot > 0.99
+                        and args.magic_prime / dataset_slot <= 1
+                    )
         elif args.data_type == "numpy":
             self.data = np.load(args.data_file).astype("int")
             self.vocab_size = args.vocab_size
-            rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
+            rank_zero_info(
+                "Current vocab size =", self.vocab_size, "(make sure it's correct)"
+            )
             self.data_size = len(self.data)
             rank_zero_info(f"Data has {self.data_size} tokens.")
         elif args.data_type == "uint16":
-            self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
+            self.data = (
+                np.fromfile(args.data_file, dtype=np.uint16)
+                .astype("int32")
+                .reshape(-1, args.my_sample_len)
+            )
             self.vocab_size = args.vocab_size
-            rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
+            rank_zero_info(
+                "Current vocab size =", self.vocab_size, "(make sure it's correct)"
+            )
             self.data_size = self.data.shape[0]
             rank_zero_info(f"Data has {self.data_size} samples.")
         elif args.data_type == "wds_img":
@@ -86,10 +107,14 @@ class MyDataset(Dataset):
             for u in unique:
                 xxObj[xx] = u
                 xx += 1
-            with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
+            with open(
+                f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le"
+            ) as vocab_file:
                 vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
             self.data_size = len(self.data)
-            rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
+            rank_zero_info(
+                f"Data has {self.data_size} tokens, {self.vocab_size} vocab size."
+            )
             self.stoi = {ch: i for i, ch in enumerate(unique)}
             self.itos = {i: ch for i, ch in enumerate(unique)}
 
@@ -104,36 +129,53 @@ class MyDataset(Dataset):
         # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
 
         if args.data_type == "wds_img":
+
             def init_wds(self, bias=0):
                 def identity(x):
                     return x
+
                 import webdataset as wds
                 import torchvision.transforms as transforms
+
                 # img_transform = transforms.Compose(
                 #     [transforms.CenterCrop(256)]
                 # )
-                img_transform = transforms.Compose([
-                    transforms.CenterCrop(512),
-                    transforms.Resize((args.my_img_size))
-                ])
-                self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
+                img_transform = transforms.Compose(
+                    [transforms.CenterCrop(512), transforms.Resize((args.my_img_size))]
+                )
+                self.data_raw = (
+                    wds.WebDataset(args.data_file, resampled=True)
+                    .shuffle(
+                        10000,
+                        initial=1000,
+                        rng=random.Random(epoch * 100000 + rank + bias * 1e9),
+                    )
+                    .decode("torchrgb")
+                    .to_tuple("jpg", "json", "txt")
+                    .map_tuple(img_transform, identity, identity)
+                )
                 for pp in self.data_raw.pipeline:
-                    if 'Resampled' in str(pp):
+                    if "Resampled" in str(pp):
                         pp.deterministic = True
+
                         def worker_seed():
-                            return rank*100000+epoch+bias*1e9
+                            return rank * 100000 + epoch + bias * 1e9
+
                         pp.worker_seed = worker_seed
                 self.data = iter(self.data_raw)
                 # print(f"WebDataset loaded for rank {rank} epoch {epoch}")
+
             if self.data == None:
                 init_wds(self)
             trial = 0
             while trial < 10:
                 try:
-                    dd = next(self.data) # jpg, json, txt
+                    dd = next(self.data)  # jpg, json, txt
                     break
                 except:
-                    print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
+                    print(
+                        f"[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]"
+                    )
                     self.error_count += 1
                     init_wds(self, self.error_count)
                     trial += 1
@@ -144,7 +186,7 @@ class MyDataset(Dataset):
             return dd[0], dd[2]
         else:
             if args.data_type == "uint16":
-                i = np.random.randint(0, self.data_size-1)
+                i = np.random.randint(0, self.data_size - 1)
                 dix = self.data[i]
                 x = torch.tensor(dix[:-1], dtype=torch.long)
                 y = torch.tensor(dix[1:], dtype=torch.long)
@@ -196,7 +238,12 @@ class MyDataset(Dataset):
                         z_sum = 0
                         isGood = False
                         for i in range(3, ctx_len):
-                            if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
+                            if (
+                                dix[i] == 27
+                                and dix[i - 1] == 34
+                                and dix[i - 2] == 187
+                                and dix[i - 3] == 187
+                            ):
                                 isGood = True
                             if dix[i] == 0:
                                 isGood = False
@@ -206,7 +253,9 @@ class MyDataset(Dataset):
                         if z_sum == 0:
                             z = [1] * ctx_len
                             i = np.random.randint(0, self.data_pile_size - req_len)
-                            dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
+                            dix = self.data_pile.get(
+                                idx=0, offset=i, length=req_len
+                            ).astype(int)
                     z = torch.tensor(z, dtype=torch.bfloat16)
 
                 x = torch.tensor(dix[:-1], dtype=torch.long)
@@ -5,6 +5,7 @@
 import functools
 import os, math, gc, importlib
 import torch
+
 # torch._C._jit_set_profiling_executor(True)
 # torch._C._jit_set_profiling_mode(True)
 import torch.nn as nn
@@ -13,7 +14,8 @@ from torch.nn import functional as F
 import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from pytorch_lightning.strategies import DeepSpeedStrategy
-if importlib.util.find_spec('deepspeed'):
+
+if importlib.util.find_spec("deepspeed"):
     import deepspeed
     from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 
@@ -28,9 +30,10 @@ LORA_CONFIG = {
 
 
 try:
-    print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
+    print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"])
 except:
-    os.environ["RWKV_MY_TESTING"] = ''
+    os.environ["RWKV_MY_TESTING"] = ""
 
+
 def __nop(ob):
     return ob
@@ -53,7 +56,26 @@ T_MAX = int(os.environ["RWKV_T_MAX"])  # TAKES LOTS OF VRAM!
 from torch.utils.cpp_extension import load
 
 if os.environ["RWKV_FLOAT_MODE"] == "bf16":
-    wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["finetune/lora/cuda/wkv_op_bf16.cpp", "finetune/lora/cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
+    wkv_cuda = load(
+        name=f"wkv_{T_MAX}_bf16",
+        sources=[
+            "finetune/lora/v4/cuda/wkv_op_bf16.cpp",
+            "finetune/lora/v4/cuda/wkv_cuda_bf16.cu",
+        ],
+        verbose=True,
+        extra_cuda_cflags=[
+            "-t 4",
+            "-std=c++17",
+            "-res-usage",
+            "--maxrregcount 60",
+            "--use_fast_math",
+            "-O3",
+            "-Xptxas -O3",
+            "--extra-device-vectorization",
+            f"-DTmax={T_MAX}",
+        ],
+    )
+
     class WKV(torch.autograd.Function):
         @staticmethod
         def forward(ctx, B, T, C, w, u, k, v):
@@ -66,10 +88,16 @@ if os.environ["RWKV_FLOAT_MODE"] == "bf16":
             u = u.contiguous()
             k = k.contiguous()
             v = v.contiguous()
-            y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
+            y = torch.empty(
+                (B, T, C),
+                device=w.device,
+                memory_format=torch.contiguous_format,
+                dtype=torch.bfloat16,
+            )
             wkv_cuda.forward(B, T, C, w, u, k, v, y)
             ctx.save_for_backward(w, u, k, v, y)
             return y
+
         @staticmethod
         def backward(ctx, gy):
             B = ctx.B
@@ -78,16 +106,54 @@ if os.environ["RWKV_FLOAT_MODE"] == "bf16":
             assert T <= T_MAX
             assert B * C % min(C, 32) == 0
             w, u, k, v, y = ctx.saved_tensors
-            gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
-            gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
-            gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
-            gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
+            gw = torch.empty(
+                (B, C),
+                device=gy.device,
+                memory_format=torch.contiguous_format,
+                dtype=torch.bfloat16,
+            )
+            gu = torch.empty(
+                (B, C),
+                device=gy.device,
+                memory_format=torch.contiguous_format,
+                dtype=torch.bfloat16,
+            )
+            gk = torch.empty(
+                (B, T, C),
+                device=gy.device,
+                memory_format=torch.contiguous_format,
+                dtype=torch.bfloat16,
+            )
+            gv = torch.empty(
+                (B, T, C),
+                device=gy.device,
+                memory_format=torch.contiguous_format,
+                dtype=torch.bfloat16,
+            )
             wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
             gw = torch.sum(gw, dim=0)
             gu = torch.sum(gu, dim=0)
             return (None, None, None, gw, gu, gk, gv)
+
 else:
-    wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["finetune/lora/cuda/wkv_op.cpp", "finetune/lora/cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
+    wkv_cuda = load(
+        name=f"wkv_{T_MAX}",
+        sources=[
+            "finetune/lora/v4/cuda/wkv_op.cpp",
+            "finetune/lora/v4/cuda/wkv_cuda.cu",
+        ],
+        verbose=True,
+        extra_cuda_cflags=[
+            "-res-usage",
+            "--maxrregcount 60",
+            "--use_fast_math",
+            "-O3",
+            "-Xptxas -O3",
+            "--extra-device-vectorization",
+            f"-DTmax={T_MAX}",
+        ],
+    )
+
     class WKV(torch.autograd.Function):
         @staticmethod
         def forward(ctx, B, T, C, w, u, k, v):
@@ -106,7 +172,9 @@ else:
                 u = u.float().contiguous()
                 k = k.float().contiguous()
                 v = v.float().contiguous()
-            y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
+            y = torch.empty(
+                (B, T, C), device=w.device, memory_format=torch.contiguous_format
+            )
             wkv_cuda.forward(B, T, C, w, u, k, v, y)
             ctx.save_for_backward(w, u, k, v, y)
             if "32" in os.environ["RWKV_FLOAT_MODE"]:
@@ -115,6 +183,7 @@ else:
                 return y.half()
             elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                 return y.bfloat16()
+
         @staticmethod
         def backward(ctx, gy):
             B = ctx.B
@@ -123,14 +192,26 @@ else:
             assert T <= T_MAX
             assert B * C % min(C, 32) == 0
             w, u, k, v, y = ctx.saved_tensors
-            gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
-            gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
-            gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
-            gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
+            gw = torch.empty(
+                (B, C), device=gy.device, memory_format=torch.contiguous_format
+            )
+            gu = torch.empty(
+                (B, C), device=gy.device, memory_format=torch.contiguous_format
+            )
+            gk = torch.empty(
+                (B, T, C), device=gy.device, memory_format=torch.contiguous_format
+            )
+            gv = torch.empty(
+                (B, T, C), device=gy.device, memory_format=torch.contiguous_format
+            )
             if "32" in os.environ["RWKV_FLOAT_MODE"]:
-                wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
+                wkv_cuda.backward(
+                    B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv
+                )
             else:
-                wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
+                wkv_cuda.backward(
+                    B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv
+                )
             gw = torch.sum(gw, dim=0)
             gu = torch.sum(gu, dim=0)
             if "32" in os.environ["RWKV_FLOAT_MODE"]:
@@ -138,7 +219,15 @@ else:
             elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                 return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
             elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
-                return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
+                return (
+                    None,
+                    None,
+                    None,
+                    gw.bfloat16(),
+                    gu.bfloat16(),
+                    gk.bfloat16(),
+                    gv.bfloat16(),
+                )
 
 
 def RUN_CUDA(B, T, C, w, u, k, v):
@@ -151,15 +240,17 @@ def RUN_CUDA(B, T, C, w, u, k, v):
 
 
 class LoraLinear(nn.Module):
-
     def __init__(self, in_features: int, out_features: int, bias: bool):
         super().__init__()
 
         self.weight = nn.Parameter(torch.empty((out_features, in_features)))
         assert bias == False, "Biased LoraLinear not supported"
 
-        r, alpha, dropout = LORA_CONFIG["r"], LORA_CONFIG[
-            "alpha"], LORA_CONFIG["dropout"]
+        r, alpha, dropout = (
+            LORA_CONFIG["r"],
+            LORA_CONFIG["alpha"],
+            LORA_CONFIG["dropout"],
+        )
         self.lora_A = nn.Parameter(torch.empty(r, in_features))
         self.lora_B = nn.Parameter(torch.empty(out_features, r))
         self.lora_dropout = nn.Dropout(dropout)
@@ -170,9 +261,9 @@ class LoraLinear(nn.Module):
         nn.init.zeros_(self.lora_B)
 
     def forward(self, x):
-        return (
-            F.linear(x, self.weight) + self.scaling *
-            F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B))
+        return F.linear(x, self.weight) + self.scaling * F.linear(
+            F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
+        )
 
 
 @functools.wraps(LoraLinear)
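
As a reading aid, a self-contained sketch of the computation that LoraLinear.forward expresses after the reformat (assumed standalone helper, not part of this commit; in the real class, scaling and the dropout rate come from LORA_CONFIG):

import torch
import torch.nn.functional as F


def lora_forward(x, weight, lora_A, lora_B, scaling, dropout_p=0.0):
    # Assumed shapes mirroring LoraLinear above: weight (out, in) frozen base
    # matrix, lora_A (r, in), lora_B (out, r).
    base = F.linear(x, weight)  # pretrained projection
    update = F.linear(F.linear(F.dropout(x, dropout_p), lora_A), lora_B)
    return base + scaling * update  # frozen projection plus scaled low-rank update

The old and new return statements both compute exactly this sum; only the line wrapping changes.
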
@@ -214,17 +305,23 @@ class RWKV_TimeMix(MyModule):
             # fancy time_decay
             decay_speed = torch.ones(args.dim_att)
             for h in range(args.dim_att):
-                decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+                decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (
+                    0.7 + 1.3 * ratio_0_to_1
+                )
             self.time_decay = nn.Parameter(decay_speed)
             # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
 
             # fancy time_first
             zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
-            self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)
+            self.time_first = nn.Parameter(
+                torch.ones(args.dim_att) * math.log(0.3) + zigzag
+            )
 
             # fancy time_mix
             self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
-            self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+            self.time_mix_v = nn.Parameter(
+                torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
+            )
             self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
 
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
@@ -235,8 +332,10 @@ class RWKV_TimeMix(MyModule):
 
         self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
 
-        if 'a' in os.environ["RWKV_MY_TESTING"]:
-            self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+        if "a" in os.environ["RWKV_MY_TESTING"]:
+            self.register_buffer(
+                "att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
+            )
             d_qkv = args.n_embd // 16
             self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
             self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
@@ -245,12 +344,17 @@ class RWKV_TimeMix(MyModule):
             with torch.no_grad():
                 self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
                 self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
-                self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+                self.time_mix_vv = nn.Parameter(
+                    torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
+                )
 
-    if 'a' not in os.environ["RWKV_MY_TESTING"]:
+    if "a" not in os.environ["RWKV_MY_TESTING"]:
+
         @MyFunction
         def jit_func(self, x):
-            xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
+            xx = self.time_shift(
+                x
+            )  # Mix x with the previous timestep to produce xk, xv, xr
             xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
             xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
             xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
@@ -263,21 +367,26 @@ class RWKV_TimeMix(MyModule):
         def forward(self, x):
             B, T, C = x.size()  # x = (Batch,Time,Channel)
             sr, k, v = self.jit_func(x)
-            rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
+            rwkv = sr * RUN_CUDA(
+                B, T, self.args.dim_att, self.time_decay, self.time_first, k, v
+            )
             return self.output(rwkv)
 
-    if 'a' in os.environ["RWKV_MY_TESTING"]:
+    if "a" in os.environ["RWKV_MY_TESTING"]:
+
         @MyFunction
         def QKV(self, q, k, v):
             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-            att = att.masked_fill(self.att_mask == 0, float('-inf'))
-            att = F.softmax(att, dim = -1)
+            att = att.masked_fill(self.att_mask == 0, float("-inf"))
+            att = F.softmax(att, dim=-1)
             x = att @ v
             return x
 
         @MyFunction
         def jit_funcQKV(self, x):
-            xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
+            xx = self.time_shift(
+                x
+            )  # Mix x with the previous timestep to produce xk, xv, xr
             xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
             xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
             xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
@@ -296,12 +405,16 @@ class RWKV_TimeMix(MyModule):
         def forward(self, x):
             B, T, C = x.size()  # x = (Batch,Time,Channel)
             sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
-            rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
+            rwkv = sr * RUN_CUDA(
+                B, T, self.args.dim_att, self.time_decay, self.time_first, k, v
+            )
             rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
             return rwkv
 
+
 ########################################################################################################
 
+
 class RWKV_ChannelMix(MyModule):
     def __init__(self, args, layer_id):
         super().__init__()
@@ -331,6 +444,7 @@ class RWKV_ChannelMix(MyModule):
         kv = self.value(k)
         return torch.sigmoid(self.receptance(xr)) * kv
 
+
 class MishGLU(MyModule):
     def __init__(self, args, layer_id):
         super().__init__()
@@ -360,6 +474,7 @@ class MishGLU(MyModule):
         b = self.bb(xb)
         return self.value(a * F.mish(b))
 
+
 ########################################################################################################
 # The RWKV Model with our blocks
 ########################################################################################################
@@ -377,15 +492,19 @@ class Block(nn.Module):
         if self.layer_id == 0:
             self.ln0 = nn.LayerNorm(args.n_embd)
             if args.my_pos_emb > 0:
-                self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
-                self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
+                self.pos_emb_x = nn.Parameter(
+                    torch.zeros((1, args.my_pos_emb, args.n_embd))
+                )
+                self.pos_emb_y = nn.Parameter(
+                    torch.zeros((args.my_pos_emb, 1, args.n_embd))
+                )
 
         if self.layer_id == 0 and self.args.pre_ffn > 0:
             self.ffnPre = RWKV_ChannelMix(args, 0)
         else:
             self.att = RWKV_TimeMix(args, layer_id)
 
-        if 'g' in os.environ["RWKV_MY_TESTING"]:
+        if "g" in os.environ["RWKV_MY_TESTING"]:
             self.ffn = MishGLU(args, layer_id)
         else:
             self.ffn = RWKV_ChannelMix(args, layer_id)
@@ -395,7 +514,9 @@ class Block(nn.Module):
             self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
             self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
             self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
-            self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+            self.register_buffer(
+                "tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
+            )
 
     def forward(self, x, x_emb=None):
         args = self.args
@ -403,7 +524,7 @@ class Block(nn.Module):
 | 
				
			|||||||
        if self.layer_id == 0:
 | 
					        if self.layer_id == 0:
 | 
				
			||||||
            x = self.ln0(x)
 | 
					            x = self.ln0(x)
 | 
				
			||||||
            if args.my_pos_emb > 0:
 | 
					            if args.my_pos_emb > 0:
 | 
				
			||||||
                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
 | 
					                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :]
 | 
				
			||||||
                x = x + pos_emb
 | 
					                x = x + pos_emb
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.layer_id == 0 and args.pre_ffn > 0:
 | 
					        if self.layer_id == 0 and args.pre_ffn > 0:
 | 
				
			||||||
@ -443,13 +564,13 @@ class RWKV(pl.LightningModule):
 | 
				
			|||||||
    def __init__(self, args):
 | 
					    def __init__(self, args):
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self.args = args
 | 
					        self.args = args
 | 
				
			||||||
        if not hasattr(args, 'dim_att'):
 | 
					        if not hasattr(args, "dim_att"):
 | 
				
			||||||
            args.dim_att = args.n_embd
 | 
					            args.dim_att = args.n_embd
 | 
				
			||||||
        if not hasattr(args, 'dim_ffn'):
 | 
					        if not hasattr(args, "dim_ffn"):
 | 
				
			||||||
            args.dim_ffn = args.n_embd * 4
 | 
					            args.dim_ffn = args.n_embd * 4
 | 
				
			||||||
        if not hasattr(args, 'tiny_att_layer'):
 | 
					        if not hasattr(args, "tiny_att_layer"):
 | 
				
			||||||
            args.tiny_att_layer = -1
 | 
					            args.tiny_att_layer = -1
 | 
				
			||||||
        if not hasattr(args, 'tiny_att_dim'):
 | 
					        if not hasattr(args, "tiny_att_dim"):
 | 
				
			||||||
            args.tiny_att_dim = -1
 | 
					            args.tiny_att_dim = -1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.emb = nn.Embedding(args.vocab_size, args.n_embd)
 | 
					        self.emb = nn.Embedding(args.vocab_size, args.n_embd)
 | 
				
			||||||
@ -462,7 +583,9 @@ class RWKV(pl.LightningModule):
 | 
				
			|||||||
        if args.head_qk > 0:
 | 
					        if args.head_qk > 0:
 | 
				
			||||||
            self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
 | 
					            self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
 | 
				
			||||||
            self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
 | 
					            self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
 | 
				
			||||||
            self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 | 
					            self.register_buffer(
 | 
				
			||||||
 | 
					                "copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def configure_optimizers(self):
 | 
					    def configure_optimizers(self):
 | 
				
			||||||
        args = self.args
 | 
					        args = self.args
 | 
				
			||||||
@ -494,19 +617,46 @@ class RWKV(pl.LightningModule):
 | 
				
			|||||||
            param_dict = {n: p for n, p in self.named_parameters()}
 | 
					            param_dict = {n: p for n, p in self.named_parameters()}
 | 
				
			||||||
            if args.my_pile_stage == 2:
 | 
					            if args.my_pile_stage == 2:
 | 
				
			||||||
                optim_groups = [
 | 
					                optim_groups = [
 | 
				
			||||||
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
 | 
					                    {
 | 
				
			||||||
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
 | 
					                        "params": [param_dict[n] for n in lr_1x],
 | 
				
			||||||
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
 | 
					                        "weight_decay": 0.0,
 | 
				
			||||||
 | 
					                        "my_lr_scale": 1.0,
 | 
				
			||||||
 | 
					                    },
 | 
				
			||||||
 | 
					                    {
 | 
				
			||||||
 | 
					                        "params": [param_dict[n] for n in lr_2x],
 | 
				
			||||||
 | 
					                        "weight_decay": 0.0,
 | 
				
			||||||
 | 
					                        "my_lr_scale": 5.0,
 | 
				
			||||||
 | 
					                    },  # test: 2e-3 / args.lr_init},
 | 
				
			||||||
 | 
					                    {
 | 
				
			||||||
 | 
					                        "params": [param_dict[n] for n in lr_3x],
 | 
				
			||||||
 | 
					                        "weight_decay": 0.0,
 | 
				
			||||||
 | 
					                        "my_lr_scale": 5.0,
 | 
				
			||||||
 | 
					                    },  # test: 3e-3 / args.lr_init},
 | 
				
			||||||
                ]
 | 
					                ]
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                optim_groups = [
 | 
					                optim_groups = [
 | 
				
			||||||
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
 | 
					                    {
 | 
				
			||||||
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
 | 
					                        "params": [param_dict[n] for n in lr_1x],
 | 
				
			||||||
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
 | 
					                        "weight_decay": 0.0,
 | 
				
			||||||
 | 
					                        "my_lr_scale": 1.0,
 | 
				
			||||||
 | 
					                    },
 | 
				
			||||||
 | 
					                    {
 | 
				
			||||||
 | 
					                        "params": [param_dict[n] for n in lr_2x],
 | 
				
			||||||
 | 
					                        "weight_decay": 0.0,
 | 
				
			||||||
 | 
					                        "my_lr_scale": 2.0,
 | 
				
			||||||
 | 
					                    },
 | 
				
			||||||
 | 
					                    {
 | 
				
			||||||
 | 
					                        "params": [param_dict[n] for n in lr_3x],
 | 
				
			||||||
 | 
					                        "weight_decay": 0.0,
 | 
				
			||||||
 | 
					                        "my_lr_scale": 3.0,
 | 
				
			||||||
 | 
					                    },
 | 
				
			||||||
                ]
 | 
					                ]
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            optim_groups = [
 | 
					            optim_groups = [
 | 
				
			||||||
                {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
 | 
					                {
 | 
				
			||||||
 | 
					                    "params": [p for n, p in self.named_parameters()],
 | 
				
			||||||
 | 
					                    "weight_decay": 0.0,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
            ]
 | 
					            ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for g in optim_groups:
 | 
					        for g in optim_groups:
 | 
				
			||||||
@ -514,8 +664,26 @@ class RWKV(pl.LightningModule):
 | 
				
			|||||||
        optim_groups = [g for g in optim_groups if len(g["params"]) > 0]
 | 
					        optim_groups = [g for g in optim_groups if len(g["params"]) > 0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.deepspeed_offload:
 | 
					        if self.deepspeed_offload:
 | 
				
			||||||
            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
 | 
					            return DeepSpeedCPUAdam(
 | 
				
			||||||
        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
 | 
					                optim_groups,
 | 
				
			||||||
 | 
					                lr=self.args.lr_init,
 | 
				
			||||||
 | 
					                betas=self.args.betas,
 | 
				
			||||||
 | 
					                eps=self.args.adam_eps,
 | 
				
			||||||
 | 
					                bias_correction=True,
 | 
				
			||||||
 | 
					                adamw_mode=False,
 | 
				
			||||||
 | 
					                weight_decay=0,
 | 
				
			||||||
 | 
					                amsgrad=False,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        return FusedAdam(
 | 
				
			||||||
 | 
					            optim_groups,
 | 
				
			||||||
 | 
					            lr=self.args.lr_init,
 | 
				
			||||||
 | 
					            betas=self.args.betas,
 | 
				
			||||||
 | 
					            eps=self.args.adam_eps,
 | 
				
			||||||
 | 
					            bias_correction=True,
 | 
				
			||||||
 | 
					            adam_w_mode=False,
 | 
				
			||||||
 | 
					            weight_decay=0,
 | 
				
			||||||
 | 
					            amsgrad=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
 | 
					        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
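The reformatted configure_optimizers hunks above only build the parameter groups; the per-group "my_lr_scale" is consumed later, when the training callback writes the scheduled learning rate back into the optimizer. That step is outside this diff, so the following is a minimal sketch of the usual pattern (the layerwise_lr flag and the trainer.optimizers[0] access are assumptions, not taken from these hunks):

# Sketch only: how a per-group "my_lr_scale" is typically applied once the base
# learning rate for the current step has been computed by the callback.
def apply_lr(trainer, base_lr, layerwise_lr=True):
    for param_group in trainer.optimizers[0].param_groups:
        if layerwise_lr and "my_lr_scale" in param_group:
            param_group["lr"] = base_lr * param_group["my_lr_scale"]
        else:
            param_group["lr"] = base_lr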
@ -589,10 +757,14 @@ class RWKV(pl.LightningModule):

            logits = self(idx)
            if sum_mask == mask.shape[0]:
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)), targets.view(-1)
                )
                # print('rank', self.global_rank, 'loss', loss.item())
            else:
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
                )
                # loss_raw = loss
                loss = torch.sum(loss * mask) / sum_mask

@ -632,7 +804,14 @@ class RWKV(pl.LightningModule):

            gain = 1.0
            scale = 1.0
            if (
                "ln_" in n
                or ".ln" in n
                or "time_" in n
                or "_mask" in n
                or "pos_emb" in n
                or ".mask." in n
            ):
                m[n] = p
            else:
                if n == "emb.weight":

@ -640,7 +819,19 @@ class RWKV(pl.LightningModule):
                else:
                    if shape[0] > shape[1]:
                        gain = math.sqrt(shape[0] / shape[1])
                    for kk in [
                        ".att.key.",
                        ".att.receptance.",
                        ".att.output.",
                        ".att.key.",
                        ".ffn.value.",
                        ".ffn.receptance.",
                        ".ffnPre.value.",
                        ".ffnPre.receptance.",
                        "head_q.",
                        ".oo.",
                        ".rr.",
                    ]:
                        if kk in n:
                            scale = 0
                    if n == "head.weight":

@ -650,7 +841,9 @@ class RWKV(pl.LightningModule):
                    if "head_q." in n:
                        scale = 0

                print(
                    f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}"
                )

                if self.args.accelerator.upper() == "GPU":
                    m[n] = torch.empty((shape[0], shape[1]), device="cuda")
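The three hunks above only pick a gain and scale per parameter name; the initialization call that consumes the pair sits outside the diff context. A rough sketch of how such a pair is commonly used (the exact branching below is an assumption, not shown in these hunks):

import torch
import torch.nn as nn

def init_from_gain_scale(shape, gain, scale):
    # Sketch only: zero-init when scale == 0 (weights meant to start as a no-op),
    # a symmetric uniform init when scale < 0, otherwise an orthogonal init
    # scaled by gain * scale.
    w = torch.empty(shape)
    if scale == 0:
        nn.init.zeros_(w)
    elif scale < 0:
        nn.init.uniform_(w, a=scale, b=-scale)
    else:
        nn.init.orthogonal_(w, gain=gain * scale)
    return w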
@ -5,15 +5,17 @@ import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from .model import LORA_CONFIG


def my_save(dd, ff):
    if "14b-run1" not in ff:
        torch.save(dd, ff)
    else:
        fn = ff.split("/")[-1]
        fff = "/dev/shm/" + fn
        torch.save(dd, fff)
        subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)


class train_callback(pl.Callback):
    def __init__(self, args):
        super().__init__()

@ -38,7 +40,9 @@ class train_callback(pl.Callback):
            if args.lr_final == 0 or args.lr_init == 0:  # linear decay
                lr = args.lr_init + (args.lr_final - args.lr_init) * progress
            else:  # exp decay
                lr = args.lr_init * math.exp(
                    math.log(args.lr_final / args.lr_init) * pow(progress, 1)
                )

            if trainer.global_step < w_step:
                lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)

@ -60,7 +64,9 @@ class train_callback(pl.Callback):
                trainer.my_loss_sum = 0
                trainer.my_loss_count = 0
                trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
                trainer.my_log.write(
                    f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n"
                )
                try:
                    print(f"\n{trainer.strategy.config}\n")
                    trainer.my_log.write(f"{trainer.strategy.config}\n")

@ -70,6 +76,7 @@ class train_callback(pl.Callback):
                if len(args.wandb) > 0:
                    print("Login to wandb...")
                    import wandb

                    wandb.init(
                        project=args.wandb,
                        name=args.run_name + " " + args.my_timestamp,

@ -102,20 +109,26 @@ class train_callback(pl.Callback):
            # self.log("s", real_step, prog_bar=True, on_step=True)

            if len(args.wandb) > 0:
                lll = {
                    "loss": trainer.my_loss,
                    "lr": trainer.my_lr,
                    "Gtokens": real_step * token_per_step / 1e9,
                }
                if kt_s > 0:
                    lll["kt/s"] = kt_s
                trainer.my_wandb.log(lll, step=int(real_step))
            if args.magic_prime > 0:
                expand_factor = 2 if args.my_qa_mask > 0 else 1
                if (
                    int(real_step)
                    == int(args.magic_prime * expand_factor // args.real_bsz) - 1
                ):
                    to_save_dict = pl_module.state_dict()
                    my_save(
                        to_save_dict,
                        f"{args.proj_dir}/rwkv-final.pth",
                    )

    def on_train_epoch_start(self, trainer, pl_module):
        args = self.args
        dataset = trainer.train_dataloader.dataset.datasets

@ -128,24 +141,28 @@ class train_callback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        args = self.args
        if trainer.is_global_zero:  # logging & save state_dict
            if (
                args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0
            ) or trainer.current_epoch == args.epoch_count - 1:
                if args.data_type == "wds_img":
                    raw_dict = pl_module.state_dict()
                    to_save_dict = {}
                    for k in raw_dict:
                        if k.startswith("encoder.") or k.startswith("decoder."):
                            to_save_dict[k] = raw_dict[k]
                else:
                    to_save_dict = pl_module.state_dict()

                if args.lora:
                    enable_time_finetune = "time" in LORA_CONFIG["parts"]
                    enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
                    lora_dict = {}
                    for name, state in to_save_dict.items():
                        if (
                            ".lora_" in name
                            or (enable_time_finetune and ".time_" in name)
                            or (enable_ln_finetune and ".ln" in name)
                        ):
                            lora_dict[name] = state
                    to_save_dict = lora_dict
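With --lora, the checkpoint written here contains only the LoRA matrices plus any time/ln parameters that were unfrozen, so it has to be combined with the base model when it is loaded. A minimal sketch of that merge, assuming the usual lora_A/lora_B key naming and an alpha / r scaling (both assumptions, not defined in this hunk):

import torch

def merge_lora_checkpoint(base_path, lora_path, out_path, lora_alpha=32, lora_r=8):
    # Sketch only: add B @ A (scaled by alpha / r) into each base weight and
    # copy over any non-LoRA parameters that were also finetuned.
    base = torch.load(base_path, map_location="cpu")
    lora = torch.load(lora_path, map_location="cpu")
    scaling = lora_alpha / lora_r
    for k, v in lora.items():
        if k.endswith(".lora_A"):
            prefix = k[: -len(".lora_A")]
            A, B = v, lora[prefix + ".lora_B"]
            base[prefix + ".weight"] += (B @ A) * scaling  # W <- W + B A * alpha / r
        elif ".lora_" not in k:
            base[k] = v  # finetuned time_/ln parameters overwrite the base copies
    torch.save(base, out_path)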
@ -155,8 +172,10 @@ class train_callback(pl.Callback):
                        f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
                    )
                except Exception as e:
                    print("Error\n\n", e, "\n\n")
            trainer.my_log.write(
                f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n"
            )
            trainer.my_log.flush()

            trainer.my_loss_sum = 0

@ -178,22 +197,22 @@ def generate_init_weight(model, init_weight_name):
                    mm[k] = src.reshape(mm[k].shape)
                except:
                    tmp = mm[k].squeeze().clone()
                    print(k, src.shape, "-->", mm[k].shape)
                    ss = src.shape[0]
                    dd = tmp.shape[0]
                    for i in range(dd):
                        pos = i / dd * ss
                        if pos >= ss - 1:
                            tmp[i] = src[ss - 1]
                        else:
                            p0 = int(math.floor(pos))
                            ii = pos - p0
                            tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii)
                    mm[k] = tmp.reshape(mm[k].shape)
                    sss = src.squeeze().float().cpu().numpy()
                    print(sss[:10], "...", sss[-10:])
                    mmm = mm[k].squeeze().float().cpu().numpy()
                    print(mmm[:10], "...", mmm[-10:])

    print(f"Save to {init_weight_name}...")
    torch.save(mm, init_weight_name)
@ -6,6 +6,7 @@ from torch.nn import functional as F
time_slot = {}
time_ref = time.time_ns()


def record_time(name):
    if name not in time_slot:
        time_slot[name] = 1e20

@ -13,20 +14,23 @@ def record_time(name):
    if tt < time_slot[name]:
        time_slot[name] = tt


class TOKENIZER:
    def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"):
        if "list" in str(type(WORD_NAME)):
            self.charMode = False
            if WORD_NAME[0] == WORD_NAME[1]:
                from transformers import PreTrainedTokenizerFast

                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
            else:
                from transformers import GPT2TokenizerFast

                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
            self.vocab_size = len(self.tokenizer)
        else:
            self.charMode = True
            with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)

            self.vocab_size = len(self.word_table)

@ -37,23 +41,25 @@ class TOKENIZER():
            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        context = context.strip().split("\n")
        for c in range(len(context)):
            context[c] = context[c].strip().strip("\u3000").strip("\r")
        context = list(filter(lambda c: c != "", context))
        context = "\n" + ("\n".join(context)).strip()
        if context == "":
            context = "\n"
        return context

    def sample_logits(
        self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None
    ):
        # out[self.UNKNOWN_CHAR] = -float('Inf')
        lastChar = int(x[-1])

        probs = F.softmax(out, dim=-1)

        if self.charMode:
            if self.itos[lastChar] == "\n":
                top_p = top_p_newline
            else:
                top_p = top_p_usual
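Between this hunk and the next, the elided body of sample_logits presumably applies the selected top_p as a nucleus-sampling cutoff before the torch.multinomial call shown below. A compact sketch of such a cutoff, written independently of the elided code:

import torch

def top_p_filter(probs, top_p):
    # Keep the smallest prefix of the sorted distribution whose cumulative mass
    # reaches top_p, zero out everything else, then renormalize.
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    cut = int((cumulative >= top_p).nonzero()[0])  # first index reaching top_p
    cutoff = sorted_probs[cut]
    filtered = torch.where(probs >= cutoff, probs, torch.zeros_like(probs))
    return filtered / filtered.sum()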
@ -81,6 +87,7 @@ class TOKENIZER():
            out = torch.multinomial(probs, num_samples=1)[0]
            return out


def MaybeIsPrime(number):
    if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
        return True

@ -121,7 +128,9 @@ def MillerRabinPrimalityTest(number):
        if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
            iterationNumber = 1

            while (iterationNumber <= timesTwoDividNumber - 1) and (
                randomNumberWithPower != number - 1
            ):
                randomNumberWithPower = pow(randomNumberWithPower, 2, number)
                iterationNumber = iterationNumber + 1
            if randomNumberWithPower != (number - 1):
@ -184,7 +184,7 @@ if __name__ == "__main__":
    args.num_sanity_val_steps = 0
    args.check_val_every_n_epoch = int(1e20)
    args.log_every_n_steps = int(1e20)
    args.max_epochs = args.epoch_count  # -1 continue forever
    args.betas = (args.beta1, args.beta2)
    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
    os.environ["RWKV_T_MAX"] = str(args.ctx_len)

@ -373,7 +373,7 @@ if __name__ == "__main__":
                    for param in module.parameters():
                        param.requires_grad = True
                elif enable_time_finetune and any(
                    n.startswith("time") for n, _ in module.named_parameters()
                ):
                    for pname, param in module.named_parameters():
                        if pname.startswith("time"):

@ -381,7 +381,7 @@ if __name__ == "__main__":
                            param.requires_grad = True

    if (
        len(args.load_model) == 0 or args.my_pile_stage == 1
    ):  # shall we build the initial weights?
        init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
        generate_init_weight(model, init_weight_name)  # save initial weights
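The two hunks above are the tail of the LoRA parameter-selection loop: everything is frozen first, then the LoRA matrices and, depending on --lora_parts, the time/ln parameters are switched back on. A condensed sketch of that overall pattern (the surrounding loop is mostly outside this diff, so treat the exact conditions as an approximation):

def freeze_for_lora(model, lora_parts=("att", "ffn", "time", "ln")):
    # Train only the LoRA weights, plus the time_/ln parameters when requested.
    enable_time = "time" in lora_parts
    enable_ln = "ln" in lora_parts
    for name, param in model.named_parameters():
        param.requires_grad = (
            ".lora_" in name
            or (enable_time and ".time_" in name)
            or (enable_ln and ".ln" in name)
        )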
@ -423,8 +423,8 @@ if __name__ == "__main__":
    )

    if (
        args.lr_init > 1e-4
        or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
    ):
        if "I_KNOW_WHAT_IM_DOING" in os.environ:
            if trainer.global_rank == 0:

@ -459,10 +459,10 @@ if __name__ == "__main__":

    if "deepspeed" in args.strategy:
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
            args.ds_bucket_mb * 1000 * 1000
        )
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
            args.ds_bucket_mb * 1000 * 1000
        )

    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
							
								
								
									
finetune/lora/v5/cuda/wkv5_cuda.cu (new file, vendored, 202 lines)
@ -0,0 +1,202 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _w += h*_N_;
    _u += h*_N_;

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
    float state[_N_] = {0};

    __syncthreads();
    w[i] = _w[i];
    u[i] = float(_u[i]);
    __syncthreads();

    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
}

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C, const int H,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
    F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _w += h*_N_;
    _u += h*_N_;
    __w += h*_N_;

    __shared__ float w_[_N_], u_[_N_];
    __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
    __syncthreads();
    w_[i] = _w[i];
    u_[i] = float(_u[i]);
    __syncthreads();

    const float w = w_[i];
    const float ww = __w[i];
    const float u = u_[i];

    float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};

    float gw = 0, gu = 0;
    const int t000 = b*T*C + h*_N_ + i;
    const int t111 = (b+1)*T*C + h*_N_ + i;
    const int t222 = t111 - 2*C;

    for (int t = t000; t < t111; t += C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();

        const float k = float(_k[t]);
        float gr = 0, gu_ = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = state[j];
            float x = k * v[j];

            gr += (u * x + s) * gy[j];
            gu_ += x * gy[j];
            s = s * w + x;
        }
        _gr[t] = F(gr);
        gu += float(_r[t]) * gu_;
    }
    _gu[b*C + h*_N_ + i] = F(gu);

    for (int t = t000; t < t222; t += C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t + 2*C]);
        __syncthreads();

        const float k = float(_k[t]);
        float gw_ = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = saaaa[j];
            float& s2 = sbbbb[j];
            float x = k * v[j];

            float tmp = w * (x + s);
            s = tmp;
            s2 = tmp + w * s2;
            gw_ += s2 * gy[j];
        }
        gw += float(_r[t + 2*C]) * gw_;
    }
    _gw[b*C + h*_N_ + i] = F(ww * gw);

    for (int t = t111 - C; t >= t000; t -= C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();

        const float rr = float(_r[t]);
        float gk = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            float x = rr * gy[j];

            gk += (u * x + s) * v[j];
            s = x + s * w;
        }
        _gk[t] = F(gk);
    }

    for (int t = t111 - C; t >= t000; t -= C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float gyy = float(_gy[t]);
        float gv = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = sdddd[j];
            float x = gyy * r[j];

            gv += (u_[j] * x + s) * k[j];
            s = x + s * w_[j];
        }
        _gv[t] = F(gv);
    }
}

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);
    assert(_N_%4 == 0);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
}

void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
{
    assert(H*_N_ == C);
    assert(_N_%4 == 0);
    kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
}
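For reference, the forward kernel above keeps one state entry per (output channel i, key channel j) pair per head and computes y_t[i] = sum_j r_t[j] * (u[j] * k_t[j] * v_t[i] + S[i][j]), then S[i][j] = w[j] * S[i][j] + k_t[j] * v_t[i]. A slow but readable PyTorch equivalent, useful only for checking the kernel (the (B, T, H, N) tensor layout below is an assumption matching C = H * _N_):

import torch

def wkv5_reference(r, k, v, w, u):
    # r, k, v: (B, T, H, N); w, u: (H, N); w is the per-channel decay. Returns y: (B, T, H, N).
    r, k, v, w, u = (x.float() for x in (r, k, v, w, u))
    B, T, H, N = r.shape
    y = torch.zeros(B, T, H, N)
    state = torch.zeros(B, H, N, N)                      # state[b, h, i, j]
    for t in range(T):
        kv = v[:, t, :, :, None] * k[:, t, :, None, :]   # outer product k_t v_t
        y[:, t] = torch.einsum("bhj,bhij->bhi", r[:, t], u[None, :, None, :] * kv + state)
        state = state * w[None, :, None, :] + kv
    return y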
							
								
								
									
finetune/lora/v5/cuda/wkv5_op.cpp (new file, vendored, 22 lines)
@ -0,0 +1,22 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);

void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv5 forward");
    m.def("backward", &backward, "wkv5 backward");
}

TORCH_LIBRARY(wkv5, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
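The bindings above expect the head size to be baked in as the _N_ macro at compile time; the Python side that compiles them is not part of this diff. A plausible JIT build with torch.utils.cpp_extension.load would look roughly like this (the head size value and the extra compiler flags are assumptions):

from torch.utils.cpp_extension import load

HEAD_SIZE = 64  # assumed head size; must satisfy C == H * HEAD_SIZE for the kernels

wkv5 = load(
    name="wkv5",
    sources=[
        "finetune/lora/v5/cuda/wkv5_op.cpp",
        "finetune/lora/v5/cuda/wkv5_cuda.cu",
    ],
    verbose=True,
    extra_cuda_cflags=[
        "--use_fast_math",
        "-O3",
        f"-D_N_={HEAD_SIZE}",
    ],
)
# wkv5.forward(B, T, C, H, r, k, v, w, u, y) then fills y in place.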
							
								
								
									
finetune/lora/v5/src/__init__.py (new file, vendored, 0 lines)

finetune/lora/v5/src/binidx.py (new file, vendored, 303 lines)
@ -0,0 +1,303 @@
from lib2to3.pgen2 import token
import os
import torch
import numpy as np
import shutil
import struct
from functools import lru_cache
from itertools import accumulate


def print_rank_0(*message):
    pass
    # """If distributed is initialized print only on rank 0."""
    # if torch.distributed.is_initialized():
    #     if torch.distributed.get_rank() == 0:
    #         print(*message, flush=True)
    # else:
    #     print(*message, flush=True)


def _warmup_mmap_file(path):
    pass
    # with open(path, "rb") as stream:
    #     while stream.read(100 * 1024 * 1024):
    #         pass


dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: float,
    7: np.double,
    8: np.uint16,
}


def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


def index_file_path(prefix_path):
    return prefix_path + ".idx"


def data_file_path(prefix_path):
    return prefix_path + ".bin"


class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, "wb")

                    # Write Magic string so we can check the file format when opening it again.
                    self._file.write(cls._HDR_MAGIC)
                    # Write version number
                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", 1))
                    # Little endian unsigned 8 Bit integer
                    self._file.write(struct.pack("<B", code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)

                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", len(sizes)))
                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order="C"))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                # Little endian unsigned 64 Bit integer
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                # Little endian unsigned 8 Bit integer
                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print_rank_0("    warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print_rank_0("    reading sizes...")
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            print_rank_0("    reading pointers...")
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )
            print_rank_0("    reading document index...")
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print_rank_0("    warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print_rank_0("    creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        print_rank_0("    creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
            )
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
            )
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
        )
        return np_array

    def pad(self, idx, length=None):
        ptr, size = self._index[idx]
        try:
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
            )
        except:
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
            )
            ptr0, _ = self._index[0]
            np_array0 = np.frombuffer(
                self._bin_buffer,
                dtype=self._index.dtype,
                count=length - size,
                offset=ptr0,
            )
            np_array = np.append(np_array, np_array0)
        return np_array

    def only(self, idx):
        ptr, size = self._index[idx]
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
        )

        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(
            data_file_path(path)
        )
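A short usage sketch (not part of the commit) of the reader above. The dataset prefix "data/sample_text_document" is a hypothetical example; a matching .idx/.bin pair produced by the usual Megatron-style binidx preprocessing, with at least 16 tokens, is assumed.

# sketch_binidx_read.py: illustrative only; assumes finetune/lora/v5/src is on sys.path
from binidx import MMapIndexedDataset

ds = MMapIndexedDataset("data/sample_text_document")  # opens .idx (header, sizes, pointers) and .bin (token ids)
print(len(ds), ds.sizes[:3])                 # number of documents and the first few document lengths
window = ds.get(idx=0, offset=0, length=16)  # 1-D numpy array of 16 token ids starting at document 0
print(window.dtype, window.shape)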
							
								
								
									
241  finetune/lora/v5/src/dataset.py  vendored  Normal file
@ -0,0 +1,241 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import json, math, random, os, sys
import numpy as np
import torch
from torch.utils.data import Dataset
from pytorch_lightning.utilities import rank_zero_info
from .binidx import MMapIndexedDataset
from .utils import MaybeIsPrime


class MyDataset(Dataset):
    def __init__(self, args):
        self.args = args

        if args.data_type == "binidx":
            self.vocab_size = args.vocab_size
            rank_zero_info(
                f"Current vocab size = {self.vocab_size} (make sure it's correct)"
            )

            if args.my_pile_version == 1:
                self.data = MMapIndexedDataset(args.data_file)
                self.data_size = (
                    len(self.data._bin_buffer) // self.data._index._dtype_size
                )
                rank_zero_info(f"Data has {self.data_size} tokens.")
            elif args.my_pile_version == 2:
                data_list = (
                    open(args.data_file, "r", encoding="utf-8")
                    .read()
                    .strip()
                    .split("\n")
                )
                data_list = [i.strip().split(" ") for i in data_list]
                self.data = []
                self.data_size = int(data_list[-1][-1])
                rank_zero_info(f"Data has {self.data_size} chunks.")
                for d in data_list:
                    data = MMapIndexedDataset(d[0])
                    data_size = len(data._bin_buffer) // data._index._dtype_size
                    assert (data_size - args.ctx_len) == int(d[1])
                    self.data += [[int(d[-1]), int(d[1]), data]]
                # rank_zero_info(self.data)

            if args.my_qa_mask > 0:
                # self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
                self.data_pile = MMapIndexedDataset(
                    "/fsx/pile_deduped/pile_0.87_deduped_text_document"
                )
                self.data_pile_size = (
                    len(self.data_pile._bin_buffer) // self.data._index._dtype_size
                )
            else:
                self.data_pile = None
                self.data_pile_size = 0

            if args.my_pile_stage > 0:
                # assert self.data_size == 332115325534 and self.vocab_size == 50277
                self.samples_per_epoch = args.epoch_steps * args.real_bsz
                assert self.samples_per_epoch == 40320
                rank_zero_info(
                    f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########"
                )
                dataset_slot = self.data_size // args.ctx_len
                if args.my_pile_stage != 4:
                    assert MaybeIsPrime(args.magic_prime)
                    assert args.magic_prime % 3 == 2
                    assert (
                        args.magic_prime / dataset_slot > 0.99
                        and args.magic_prime / dataset_slot <= 1
                    )
        elif args.data_type == "numpy":
            self.data = np.load(args.data_file).astype("int")
            self.vocab_size = args.vocab_size
            rank_zero_info(
                f"Current vocab size = {self.vocab_size} (make sure it's correct)"
            )
            self.data_size = len(self.data)
            rank_zero_info(f"Data has {self.data_size} tokens.")
        elif args.data_type == "uint16":
            self.data = (
                np.fromfile(args.data_file, dtype=np.uint16)
                .astype("int32")
                .reshape(-1, args.my_sample_len)
            )
            self.vocab_size = args.vocab_size
            rank_zero_info(
                f"Current vocab size = {self.vocab_size} (make sure it's correct)"
            )
            self.data_size = self.data.shape[0]
            rank_zero_info(f"Data has {self.data_size} samples.")
        else:
            if args.data_type == "dummy":
                rank_zero_info("Building dummy data...")
                self.data = ""
                for i in range(100000):
                    aa = (i) % 10000
                    bb = (i * i) % 10000
                    cc = aa + bb
                    self.data += f".{aa}+{bb}={cc}."
            else:
                self.data = open(args.data_file, "r", encoding=args.data_type).read()
            rank_zero_info("Building token list...")
            unique = sorted(list(set(self.data)))
            self.vocab_size = len(unique)
            # rank_zero_info()
            # for u in unique:
            #     print(u, end=' ')
            # rank_zero_info('\n\n')
            xx = 0
            xxObj = {}
            for u in unique:
                xxObj[xx] = u
                xx += 1
            with open(
                f"{args.proj_dir}/vocab.json", "w", encoding="utf-8"
            ) as vocab_file:
                vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
            self.data_size = len(self.data)
            rank_zero_info(
                f"Data has {self.data_size} tokens, {self.vocab_size} vocab size."
            )
            self.stoi = {ch: i for i, ch in enumerate(unique)}
            self.itos = {i: ch for i, ch in enumerate(unique)}

    def __len__(self):
        return self.args.epoch_steps * self.args.micro_bsz

    def __getitem__(self, idx):
        args = self.args
        rank = self.global_rank
        epoch = self.real_epoch
        world_size = self.world_size
        # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")

        if args.data_type == "uint16":
            i = np.random.randint(0, self.data_size - 1)
            dix = self.data[i]
            x = torch.tensor(dix[:-1], dtype=torch.long)
            y = torch.tensor(dix[1:], dtype=torch.long)
        else:
            ctx_len = args.ctx_len
            req_len = ctx_len + 1
            magic_prime = args.magic_prime
            data = self.data

            if args.my_pile_stage > 0:
                ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank

                if args.my_qa_mask > 0:
                    ii_orig = ii
                    if ii % 2 == 0:
                        ii = -1
                        data = self.data_pile
                    else:
                        ii = ii // 2
                if data == self.data_pile:
                    i = np.random.randint(0, self.data_pile_size - req_len)
                else:
                    if args.my_pile_stage == 4 or ii < args.my_random_steps:
                        # cheat: pick a random spot in dataset
                        if args.my_pile_version == 1:
                            i = np.random.randint(0, self.data_size - req_len)
                        else:
                            i = np.random.randint(0, self.data_size)
                    else:
                        ii = ii - args.my_random_steps
                        factor = (math.sqrt(5) - 1) / 2
                        factor = int(magic_prime * factor)
                        i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
                        i = i + args.my_pile_shift
                # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
            else:
                # cheat: pick a random spot in dataset
                i = np.random.randint(0, self.data_size - req_len)

            if args.data_type == "binidx":
                if args.my_pile_version == 1:
                    dix = data.get(idx=0, offset=i, length=req_len).astype(int)
                else:
                    # self.data : cutoff, chunk_count, data
                    for j in range(len(data)):
                        if i < data[j][0]:
                            ii = i
                            i = (i - (data[j - 1][0] if j > 0 else 0)) % data[j][1]
                            dix = (
                                data[j][2]
                                .get(idx=0, offset=i, length=req_len)
                                .astype(int)
                            )
                            # print(ii, j, i)
                            break
            elif args.data_type == "numpy":
                dix = data[i : i + req_len]
            else:
                dix = [self.stoi[s] for s in data[i : i + req_len]]

            if args.my_qa_mask == 1:
                if data == self.data_pile:
                    z = [1] * ctx_len
                else:
                    z = [0] * ctx_len
                    z_sum = 0
                    isGood = False
                    for i in range(3, ctx_len):
                        if (
                            dix[i] == 27
                            and dix[i - 1] == 34
                            and dix[i - 2] == 187
                            and dix[i - 3] == 187
                        ):
                            isGood = True
                        if dix[i] == 0:
                            isGood = False
                        if isGood:
                            z[i] = 1
                            z_sum += 1
                    if z_sum == 0:
                        z = [1] * ctx_len
                        i = np.random.randint(0, self.data_pile_size - req_len)
                        dix = self.data_pile.get(
                            idx=0, offset=i, length=req_len
                        ).astype(int)
                z = torch.tensor(z, dtype=torch.bfloat16)

            x = torch.tensor(dix[:-1], dtype=torch.long)
            y = torch.tensor(dix[1:], dtype=torch.long)

            # if ii_orig < 50:
            #     # if rank == 1:
            #     print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
            # else:
            #     exit(0)

            if args.my_qa_mask == 1:
                return x, y, z

            return x, y
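A minimal sketch (not part of the commit) of why the magic_prime sampling in __getitem__ above visits every ctx_len-sized slot exactly once: for a prime p with p % 3 == 2, cubing is a bijection modulo p, so ii -> (factor * ii**3) % p enumerates all slot indices without repetition. The small prime 2003 below is a hypothetical stand-in for a real --magic_prime value.

# sketch_magic_prime.py: illustrative only
import math

p = 2003                                   # hypothetical magic_prime: prime, and p % 3 == 2
factor = int(p * (math.sqrt(5) - 1) / 2)   # same golden-ratio factor as in dataset.py
slots = {(factor * ii ** 3) % p for ii in range(p)}
assert len(slots) == p                     # the map is a permutation of slot indices 0..p-1
print(sorted(slots)[:5], "...")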
							
								
								
									
819  finetune/lora/v5/src/model.py  vendored  Normal file
@ -0,0 +1,819 @@
 | 
				
			|||||||
 | 
					########################################################################################################
 | 
				
			||||||
 | 
					# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
 | 
				
			||||||
 | 
					########################################################################################################
 | 
				
			||||||
 | 
					import functools
 | 
				
			||||||
 | 
					import os, math, gc, importlib
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# torch._C._jit_set_profiling_executor(True)
 | 
				
			||||||
 | 
					# torch._C._jit_set_profiling_mode(True)
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					from torch.utils.checkpoint import checkpoint as torch_checkpoint
 | 
				
			||||||
 | 
					from torch.nn import functional as F
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 | 
				
			||||||
 | 
					from pytorch_lightning.strategies import DeepSpeedStrategy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if importlib.util.find_spec("deepspeed"):
 | 
				
			||||||
 | 
					    import deepspeed
 | 
				
			||||||
 | 
					    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# lora-config
 | 
				
			||||||
 | 
					LORA_CONFIG = {
 | 
				
			||||||
 | 
					    "r": 0,
 | 
				
			||||||
 | 
					    "alpha": 0,
 | 
				
			||||||
 | 
					    "dropout": 0,
 | 
				
			||||||
 | 
					    "parts": {"att", "ln", "time"},
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"])
 | 
				
			||||||
 | 
					except:
 | 
				
			||||||
 | 
					    os.environ["RWKV_MY_TESTING"] = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def __nop(ob):
 | 
				
			||||||
 | 
					    return ob
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MyModule = nn.Module
 | 
				
			||||||
 | 
					MyFunction = __nop
 | 
				
			||||||
 | 
					if os.environ["RWKV_JIT_ON"] == "1":
 | 
				
			||||||
 | 
					    MyModule = torch.jit.ScriptModule
 | 
				
			||||||
 | 
					    MyFunction = torch.jit.script_method
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					########################################################################################################
 | 
				
			||||||
 | 
					# CUDA Kernel
 | 
				
			||||||
 | 
					########################################################################################################
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from torch.utils.cpp_extension import load
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"])
 | 
				
			||||||
 | 
					wkv5_cuda = load(
 | 
				
			||||||
 | 
					    name="wkv5",
 | 
				
			||||||
 | 
					    sources=[
 | 
				
			||||||
 | 
					        "finetune/lora/v5/cuda/wkv5_op.cpp",
 | 
				
			||||||
 | 
					        f"finetune/lora/v5/cuda/wkv5_cuda.cu",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    verbose=True,
 | 
				
			||||||
 | 
					    extra_cuda_cflags=[
 | 
				
			||||||
 | 
					        "-res-usage",
 | 
				
			||||||
 | 
					        "--use_fast_math",
 | 
				
			||||||
 | 
					        "-O3",
 | 
				
			||||||
 | 
					        "-Xptxas -O3",
 | 
				
			||||||
 | 
					        "--extra-device-vectorization",
 | 
				
			||||||
 | 
					        f"-D_N_={HEAD_SIZE}",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WKV_5(torch.autograd.Function):
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def forward(ctx, B, T, C, H, r, k, v, w, u):
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            assert r.dtype == torch.bfloat16
 | 
				
			||||||
 | 
					            assert k.dtype == torch.bfloat16
 | 
				
			||||||
 | 
					            assert v.dtype == torch.bfloat16
 | 
				
			||||||
 | 
					            assert w.dtype == torch.bfloat16
 | 
				
			||||||
 | 
					            assert u.dtype == torch.bfloat16
 | 
				
			||||||
 | 
					            assert HEAD_SIZE == C // H
 | 
				
			||||||
 | 
					            ctx.B = B
 | 
				
			||||||
 | 
					            ctx.T = T
 | 
				
			||||||
 | 
					            ctx.C = C
 | 
				
			||||||
 | 
					            ctx.H = H
 | 
				
			||||||
 | 
					            assert r.is_contiguous()
 | 
				
			||||||
 | 
					            assert k.is_contiguous()
 | 
				
			||||||
 | 
					            assert v.is_contiguous()
 | 
				
			||||||
 | 
					            assert w.is_contiguous()
 | 
				
			||||||
 | 
					            assert u.is_contiguous()
 | 
				
			||||||
 | 
					            ew = (-torch.exp(w.float())).contiguous()
 | 
				
			||||||
 | 
					            eew = (torch.exp(ew)).contiguous()
 | 
				
			||||||
 | 
					            ctx.save_for_backward(r, k, v, eew, ew, u)
 | 
				
			||||||
 | 
					            y = torch.empty(
 | 
				
			||||||
 | 
					                (B, T, C),
 | 
				
			||||||
 | 
					                device=r.device,
 | 
				
			||||||
 | 
					                dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					                memory_format=torch.contiguous_format,
 | 
				
			||||||
 | 
					            )  # .uniform_(-1, 1)
 | 
				
			||||||
 | 
					            wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
 | 
				
			||||||
 | 
					            return y
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def backward(ctx, gy):
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            assert gy.dtype == torch.bfloat16
 | 
				
			||||||
 | 
					            B = ctx.B
 | 
				
			||||||
 | 
					            T = ctx.T
 | 
				
			||||||
 | 
					            C = ctx.C
 | 
				
			||||||
 | 
					            H = ctx.H
 | 
				
			||||||
 | 
					            assert gy.is_contiguous()
 | 
				
			||||||
 | 
					            r, k, v, eew, ew, u = ctx.saved_tensors
 | 
				
			||||||
 | 
					            gr = torch.empty(
 | 
				
			||||||
 | 
					                (B, T, C),
 | 
				
			||||||
 | 
					                device=gy.device,
 | 
				
			||||||
 | 
					                requires_grad=False,
 | 
				
			||||||
 | 
					                dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					                memory_format=torch.contiguous_format,
 | 
				
			||||||
 | 
					            )  # .uniform_(-1, 1)
 | 
				
			||||||
 | 
					            gk = torch.empty(
 | 
				
			||||||
 | 
					                (B, T, C),
 | 
				
			||||||
 | 
					                device=gy.device,
 | 
				
			||||||
 | 
					                requires_grad=False,
 | 
				
			||||||
 | 
					                dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					                memory_format=torch.contiguous_format,
 | 
				
			||||||
 | 
					            )  # .uniform_(-1, 1)
 | 
				
			||||||
 | 
					            gv = torch.empty(
 | 
				
			||||||
 | 
					                (B, T, C),
 | 
				
			||||||
 | 
					                device=gy.device,
 | 
				
			||||||
 | 
					                requires_grad=False,
 | 
				
			||||||
 | 
					                dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					                memory_format=torch.contiguous_format,
 | 
				
			||||||
 | 
					            )  # .uniform_(-1, 1)
 | 
				
			||||||
 | 
					            gw = torch.empty(
 | 
				
			||||||
 | 
					                (B, C),
 | 
				
			||||||
 | 
					                device=gy.device,
 | 
				
			||||||
 | 
					                requires_grad=False,
 | 
				
			||||||
 | 
					                dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					                memory_format=torch.contiguous_format,
 | 
				
			||||||
 | 
					            )  # .uniform_(-1, 1)
 | 
				
			||||||
 | 
					            gu = torch.empty(
 | 
				
			||||||
 | 
					                (B, C),
 | 
				
			||||||
 | 
					                device=gy.device,
 | 
				
			||||||
 | 
					                requires_grad=False,
 | 
				
			||||||
 | 
					                dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					                memory_format=torch.contiguous_format,
 | 
				
			||||||
 | 
					            )  # .uniform_(-1, 1)
 | 
				
			||||||
 | 
					            wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
 | 
				
			||||||
 | 
					            gw = torch.sum(gw, 0).view(H, C // H)
 | 
				
			||||||
 | 
					            gu = torch.sum(gu, 0).view(H, C // H)
 | 
				
			||||||
 | 
					            return (None, None, None, None, gr, gk, gv, gw, gu)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
 | 
				
			||||||
 | 
					    return WKV_5.apply(B, T, C, H, r, k, v, w, u)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#################################################################
 | 
				
			||||||
 | 
					class LoraLinear(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, in_features: int, out_features: int, bias: bool):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
 | 
				
			||||||
 | 
					        assert bias == False, "Biased LoraLinear not supported"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        r, alpha, dropout = (
 | 
				
			||||||
 | 
					            LORA_CONFIG["r"],
 | 
				
			||||||
 | 
					            LORA_CONFIG["alpha"],
 | 
				
			||||||
 | 
					            LORA_CONFIG["dropout"],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.lora_A = nn.Parameter(torch.empty(r, in_features))
 | 
				
			||||||
 | 
					        self.lora_B = nn.Parameter(torch.empty(out_features, r))
 | 
				
			||||||
 | 
					        self.lora_dropout = nn.Dropout(dropout)
 | 
				
			||||||
 | 
					        self.scaling = alpha / r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
 | 
				
			||||||
 | 
					        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
 | 
				
			||||||
 | 
					        nn.init.zeros_(self.lora_B)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        return F.linear(x, self.weight) + self.scaling * F.linear(
 | 
				
			||||||
 | 
					            F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@functools.wraps(LoraLinear)
 | 
				
			||||||
 | 
					def make_linear_att(*args, **kwargs):
 | 
				
			||||||
 | 
					    if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
 | 
				
			||||||
 | 
					        return LoraLinear(*args, **kwargs)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return nn.Linear(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@functools.wraps(LoraLinear)
 | 
				
			||||||
 | 
					def make_linear_ffn(*args, **kwargs):
 | 
				
			||||||
 | 
					    if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
 | 
				
			||||||
 | 
					        return LoraLinear(*args, **kwargs)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return nn.Linear(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					########################################################################################################
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RWKV_TimeMix_RWKV5(MyModule):
 | 
				
			||||||
 | 
					    def __init__(self, args, layer_id):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.args = args
 | 
				
			||||||
 | 
					        self.layer_id = layer_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.head_size = args.head_size_a
 | 
				
			||||||
 | 
					        assert HEAD_SIZE == self.head_size  # change HEAD_SIZE to match args.head_size_a
 | 
				
			||||||
 | 
					        self.n_head = args.dim_att // self.head_size
 | 
				
			||||||
 | 
					        assert args.dim_att % self.n_head == 0
 | 
				
			||||||
 | 
					        self.head_size_divisor = args.head_size_divisor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
 | 
				
			||||||
 | 
					            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
 | 
				
			||||||
 | 
					            ddd = torch.ones(1, 1, args.n_embd)
 | 
				
			||||||
 | 
					            for i in range(args.n_embd):
 | 
				
			||||||
 | 
					                ddd[0, 0, i] = i / args.n_embd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # fancy time_mix
 | 
				
			||||||
 | 
					            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
 | 
				
			||||||
 | 
					            self.time_mix_v = nn.Parameter(
 | 
				
			||||||
 | 
					                torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
 | 
				
			||||||
 | 
					            self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # fancy time_decay
 | 
				
			||||||
 | 
					            decay_speed = torch.ones(args.dim_att)
 | 
				
			||||||
 | 
					            for n in range(args.dim_att):
 | 
				
			||||||
 | 
					                decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
 | 
				
			||||||
 | 
					                    0.7 + 1.3 * ratio_0_to_1
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            self.time_decay = nn.Parameter(
 | 
				
			||||||
 | 
					                decay_speed.reshape(self.n_head, self.head_size)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            tmp = torch.zeros(args.dim_att)
 | 
				
			||||||
 | 
					            for n in range(args.dim_att):
 | 
				
			||||||
 | 
					                zigzag = ((n + 1) % 3 - 1) * 0.1
 | 
				
			||||||
 | 
					                tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
 | 
				
			||||||
 | 
					        self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
 | 
				
			||||||
 | 
					        self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
 | 
				
			||||||
 | 
					        self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
 | 
				
			||||||
 | 
					        self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @MyFunction
 | 
				
			||||||
 | 
					    def jit_func(self, x):
 | 
				
			||||||
 | 
					        B, T, C = x.size()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        xx = self.time_shift(
 | 
				
			||||||
 | 
					            x
 | 
				
			||||||
 | 
					        )  # Mix x with the previous timestep to produce xk, xv, xr
 | 
				
			||||||
 | 
					        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
 | 
				
			||||||
 | 
					        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
 | 
				
			||||||
 | 
					        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
 | 
				
			||||||
 | 
					        xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        r = self.receptance(xr)
 | 
				
			||||||
 | 
					        k = self.key(xk)
 | 
				
			||||||
 | 
					        v = self.value(xv)
 | 
				
			||||||
 | 
					        g = F.silu(self.gate(xg))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return r, k, v, g
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @MyFunction
 | 
				
			||||||
 | 
					    def jit_func_2(self, x, g):
 | 
				
			||||||
 | 
					        B, T, C = x.size()
 | 
				
			||||||
 | 
					        x = x.view(B * T, C)
 | 
				
			||||||
 | 
					        x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
 | 
				
			||||||
 | 
					        x = self.output(x * g)
 | 
				
			||||||
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        B, T, C = x.size()
 | 
				
			||||||
 | 
					        H = self.n_head
 | 
				
			||||||
 | 
					        r, k, v, g = self.jit_func(x)
 | 
				
			||||||
 | 
					        x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.jit_func_2(x, g)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					########################################################################################################
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
class RWKV_ChannelMix(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))

        self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
        self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        k = torch.relu(k) ** 2
        kv = self.value(k)
        return torch.sigmoid(self.receptance(xr)) * kv

class MishGLU(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)

            x = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                x[0, 0, i] = i / args.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        a = self.aa(xa)
        b = self.bb(xb)
        return self.value(a * F.mish(b))


########################################################################################################
# The RWKV Model with our blocks
########################################################################################################

class Block(nn.Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(args.n_embd)
            if args.my_pos_emb > 0:
                self.pos_emb_x = nn.Parameter(
                    torch.zeros((1, args.my_pos_emb, args.n_embd))
                )
                self.pos_emb_y = nn.Parameter(
                    torch.zeros((args.my_pos_emb, 1, args.n_embd))
                )

        if self.layer_id == 0 and self.args.pre_ffn > 0:
            self.ffnPre = RWKV_ChannelMix(args, 0)
        else:
            self.att = RWKV_TimeMix_RWKV5(args, layer_id)

        if "g" in os.environ["RWKV_MY_TESTING"]:
            self.ffn = MishGLU(args, layer_id)
        else:
            self.ffn = RWKV_ChannelMix(args, layer_id)

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            self.tiny_ln = nn.LayerNorm(args.n_embd)
            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
            self.register_buffer(
                "tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
            )

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

    def forward(self, x, x_emb=None):
        args = self.args
        B, T, C = x.size()
        if self.layer_id == 0:
            x = self.ln0(x)
            if args.my_pos_emb > 0:
                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :]
                x = x + pos_emb

        if self.args.dropout == 0:
            if self.layer_id == 0 and args.pre_ffn > 0:
                x = x + self.ffnPre(self.ln1(x))
            else:
                x = x + self.att(self.ln1(x))
            x = x + self.ffn(self.ln2(x))
        else:
            if self.layer_id == 0 and args.pre_ffn > 0:
                x = self.drop0(x + self.ffnPre(self.ln1(x)))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            x = self.drop1(x + self.ffn(self.ln2(x)))

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            xx = self.tiny_ln(x)
            q = self.tiny_q(xx)[:, :T, :]
            k = self.tiny_k(xx)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
            x = x + c @ self.tiny_v(x_emb)
        return x

class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)

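# --- Illustrative sketch (editor addition, not part of the vendored file) ---
# L2Wrap.backward above passes the loss gradient through unchanged and adds a tiny
# auxiliary gradient that pulls only the argmax logit of every (batch, time) position
# toward zero. The helper below reproduces that auxiliary gradient in isolation.
def _l2wrap_extra_grad_sketch(y):
    # y: (B, T, vocab) logits; returns the extra gradient L2Wrap would emit for y
    factor = 1e-4 / (y.shape[0] * y.shape[1])
    maxx, ids = torch.max(y, -1, keepdim=True)
    gy = torch.zeros_like(y)
    gy.scatter_(-1, ids, maxx * factor)  # exactly one nonzero entry per (b, t) row
    return gy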
class RWKV(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        if not hasattr(args, "dim_att"):
            args.dim_att = args.n_embd
        if not hasattr(args, "dim_ffn"):
            args.dim_ffn = args.n_embd * 4
        if not hasattr(args, "tiny_att_layer"):
            args.tiny_att_layer = -1
        if not hasattr(args, "tiny_att_dim"):
            args.tiny_att_dim = -1
        assert args.n_embd % 32 == 0
        assert args.dim_att % 32 == 0
        assert args.dim_ffn % 32 == 0

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        if args.head_qk > 0:
            self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.register_buffer(
                "copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
            )
        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)

    def configure_optimizers(self):
        args = self.args

        lr_decay = set()
        lr_1x = set()
        lr_2x = set()
        lr_3x = set()
        for n, p in self.named_parameters():
            if ("time_mix" in n) and (args.layerwise_lr > 0):
                if args.my_pile_stage == 2:
                    lr_2x.add(n)
                else:
                    lr_1x.add(n)
            elif ("time_decay" in n) and (args.layerwise_lr > 0):
                if args.my_pile_stage == 2:
                    lr_3x.add(n)
                else:
                    lr_2x.add(n)
            elif ("time_faaaa" in n) and (args.layerwise_lr > 0):
                if args.my_pile_stage == 2:
                    lr_2x.add(n)
                else:
                    lr_1x.add(n)
            elif ("time_first" in n) and (args.layerwise_lr > 0):
                lr_3x.add(n)
            elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0):
                lr_decay.add(n)
            else:
                lr_1x.add(n)

        lr_decay = sorted(list(lr_decay))
        lr_1x = sorted(list(lr_1x))
        lr_2x = sorted(list(lr_2x))
        lr_3x = sorted(list(lr_3x))
        # print('decay', lr_decay)
        # print('1x', lr_1x)
        # print('2x', lr_2x)
        # print('3x', lr_3x)
        param_dict = {n: p for n, p in self.named_parameters()}

        if args.layerwise_lr > 0:
            if args.my_pile_stage == 2:
                optim_groups = [
                    {
                        "params": [param_dict[n] for n in lr_1x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 1.0,
                    },
                    {
                        "params": [param_dict[n] for n in lr_2x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 5.0,
                    },  # test: 2e-3 / args.lr_init},
                    {
                        "params": [param_dict[n] for n in lr_3x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 5.0,
                    },  # test: 3e-3 / args.lr_init},
                ]
            else:
                optim_groups = [
                    {
                        "params": [param_dict[n] for n in lr_1x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 1.0,
                    },
                    {
                        "params": [param_dict[n] for n in lr_2x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 2.0,
                    },
                    {
                        "params": [param_dict[n] for n in lr_3x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 3.0,
                    },
                ]
        else:
            optim_groups = [
                {
                    "params": [param_dict[n] for n in lr_1x],
                    "weight_decay": 0.0,
                    "my_lr_scale": 1.0,
                }
            ]

        if args.weight_decay > 0:
            optim_groups += [
                {
                    "params": [param_dict[n] for n in lr_decay],
                    "weight_decay": args.weight_decay,
                    "my_lr_scale": 1.0,
                }
            ]
            if self.deepspeed_offload:
                return DeepSpeedCPUAdam(
                    optim_groups,
                    lr=self.args.lr_init,
                    betas=self.args.betas,
                    eps=self.args.adam_eps,
                    bias_correction=True,
                    adamw_mode=True,
                    amsgrad=False,
                )
            return FusedAdam(
                optim_groups,
                lr=self.args.lr_init,
                betas=self.args.betas,
                eps=self.args.adam_eps,
                bias_correction=True,
                adam_w_mode=True,
                amsgrad=False,
            )
        else:
            if self.deepspeed_offload:
                return DeepSpeedCPUAdam(
                    optim_groups,
                    lr=self.args.lr_init,
                    betas=self.args.betas,
                    eps=self.args.adam_eps,
                    bias_correction=True,
                    adamw_mode=False,
                    weight_decay=0,
                    amsgrad=False,
                )
            return FusedAdam(
                optim_groups,
                lr=self.args.lr_init,
                betas=self.args.betas,
                eps=self.args.adam_eps,
                bias_correction=True,
                adam_w_mode=False,
                weight_decay=0,
                amsgrad=False,
            )
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)

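    # --- Illustrative sketch (editor addition, not part of the vendored file) ---
    # The "my_lr_scale" entries built above are not read by the optimizer itself; the
    # training callback in src/trainer.py multiplies the scheduled base LR by each
    # group's scale on every batch, roughly as sketched below.
    @staticmethod
    def _apply_layerwise_lr_sketch(param_groups, base_lr):
        for group in param_groups:
            group["lr"] = base_lr * group.get("my_lr_scale", 1.0)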
    @property
    def deepspeed_offload(self) -> bool:
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            cfg = strategy.config["zero_optimization"]
            return cfg.get("offload_optimizer") or cfg.get("offload_param")
        return False

    def forward(self, idx):
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)
        x_emb = x

        if args.dropout > 0:
            x = self.drop0(x)
        if args.tiny_att_dim > 0:
            for block in self.blocks:
                if args.grad_cp == 1:
                    if args.lora:
                        x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
                    else:
                        x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
                else:
                    x = block(x, x_emb)
        else:
            for block in self.blocks:
                if args.grad_cp == 1:
                    if args.lora:
                        x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
                    else:
                        x = deepspeed.checkpointing.checkpoint(block, x)
                else:
                    x = block(x)

        x = self.ln_out(x)

        if args.head_qk > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x

    def training_step(self, batch, batch_idx):
        args = self.args
        if args.my_qa_mask != 1:
            idx, targets = batch
            logits = self(idx)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            # if '0' in os.environ["RWKV_MY_TESTING"]:
            #     print('logits', logits)
            #     torch.set_printoptions(threshold=10000)
            #     print('idx', idx)
            #     exit(0)
        else:
            idx, targets, mask = batch
            mask = mask.view(-1)
            sum_mask = torch.sum(mask).item()
            # if sum_mask == 0:
            #     return torch.tensor([0.0], requires_grad=True)

            logits = self(idx)
            if sum_mask == mask.shape[0]:
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)), targets.view(-1)
                )
                # print('rank', self.global_rank, 'loss', loss.item())
            else:
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
                )
                # loss_raw = loss
                loss = torch.sum(loss * mask) / sum_mask

                # torch.set_printoptions(threshold=10000)
                # if True: #self.global_rank == 1:
                #     tmp = ''
                #     sss = 0
                #     ccc = 0
                #     for i in range(mask.shape[0]):
                #         if mask[i] > 0:
                #             tmp += str(idx.view(-1)[i].item()) + ','
                #             sss += loss_raw.view(-1)[i].float().item()
                #             ccc += 1
                #     print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
        return L2Wrap.apply(loss, logits)

    def training_step_end(self, batch_parts):
        if pl.__version__[0] != "2":
            all = self.all_gather(batch_parts)
            if self.trainer.is_global_zero:
                self.trainer.my_loss_all = all

    def generate_init_weight(self):
        print(
            f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        m = {}
        for n in self.state_dict():
            p = self.state_dict()[n]
            shape = p.shape

            gain = 1.0
            scale = 1.0
            if (
                "ln_" in n
                or ".ln" in n
                or "time_" in n
                or "_mask" in n
                or "pos_emb" in n
                or ".mask." in n
            ):
                if "ln_x.weight" in n:
                    layer_scale = (1 + int(n.split(".")[1])) / self.args.n_layer
                    m[n] = (p * 0.0) + (layer_scale**0.7)
                else:
                    m[n] = p
            else:
                if n == "emb.weight":
                    scale = -1 * self.args.lr_init
                else:
                    if shape[0] > shape[1]:
                        gain = math.sqrt(shape[0] / shape[1])

                    zero = [
                        ".att.output.",
                        ".ffn.value.",
                        ".ffn.receptance.",
                        ".ffnPre.value.",
                        ".ffnPre.receptance.",
                        "head_q.",
                        ".oo.",
                        ".rr.",
                    ]

                    for kk in zero:
                        if kk in n:
                            scale = 0
                    if n == "head.weight":
                        scale = 0.5
                    if "head_k." in n:
                        scale = 0.1
                    if "head_q." in n:
                        scale = 0

                print(
                    f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}"
                )

                if self.args.accelerator.upper() == "GPU":
                    m[n] = torch.empty((shape[0], shape[1]), device="cuda")
                else:
                    m[n] = torch.empty((shape[0], shape[1]))

                if scale == 0:
                    nn.init.zeros_(m[n])
                elif scale < 0:
                    nn.init.uniform_(m[n], a=scale, b=-scale)
                else:
                    nn.init.orthogonal_(m[n], gain=gain * scale)

            m[n] = m[n].cpu()
            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
                m[n] = m[n].half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                m[n] = m[n].bfloat16()

            # if n == "emb.weight":
            #     print(m[n])

        gc.collect()
        torch.cuda.empty_cache()
        return m
							
								
								
									
310 finetune/lora/v5/src/trainer.py vendored Normal file
@ -0,0 +1,310 @@
import os, math, time, datetime, subprocess
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from .model import LORA_CONFIG


def my_save(args, trainer, dd, ff):
    if "14b-run1" in ff:
        fn = ff.split("/")[-1]
        fff = "/dev/shm/" + fn
        torch.save(dd, fff)
        subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
    elif ("world/14b" in ff) or ("world/7b" in ff):
        aa = ff.split("/")[1]
        fn = ff.split("/")[-1]
        fff = f"/dev/shm/{aa}-{fn}"
        torch.save(dd, fff)
        subprocess.Popen(
            f" aws s3 mv {fff} s3://rwkv-world/{aa}-{fn} --quiet", shell=True
        )
    else:
        if "deepspeed_stage_3" in args.strategy:
            trainer.save_checkpoint(ff, weights_only=True)
        else:
            torch.save(dd, ff)

class train_callback(pl.Callback):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        args = self.args
        # if args.cuda_cleanup > 0:
        #     torch.cuda.empty_cache()
        real_step = trainer.global_step + args.epoch_begin * args.epoch_steps

        # LR schedule
        w_step = args.warmup_steps
        if args.lr_final == args.lr_init or args.epoch_count == 0:
            lr = args.lr_init
        else:
            decay_step = real_step - args.my_pile_edecay * args.epoch_steps
            decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
            progress = (decay_step - w_step + 1) / (decay_total - w_step)
            progress = min(1, max(0, progress))

            if args.lr_final == 0 or args.lr_init == 0:  # linear decay
                lr = args.lr_init + (args.lr_final - args.lr_init) * progress
            else:  # exp decay
                lr = args.lr_init * math.exp(
                    math.log(args.lr_final / args.lr_init) * pow(progress, 1)
                )
            # if trainer.is_global_zero:
            #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)

        if args.my_exit_tokens != 0:  # cosine decay
            real_tokens = real_step * args.ctx_len * args.real_bsz
            warmup_tokens = w_step * args.ctx_len * args.real_bsz
            progress = (real_tokens - warmup_tokens) / (
                abs(args.my_exit_tokens) - warmup_tokens
            )
            progress = max(0, min(1, progress))
            lr_final_factor = args.lr_final / args.lr_init
            lr_mult = (0.5 + lr_final_factor / 2) + (
                0.5 - lr_final_factor / 2
            ) * math.cos(math.pi * progress)
            if args.my_exit_tokens > 0:
                lr = args.lr_init * lr_mult
            else:
                lr = (lr + args.lr_init * lr_mult) / 2
            if progress >= 1:
                if (trainer.is_global_zero) or ("deepspeed_stage_3" in args.strategy):
                    my_save(
                        args,
                        trainer,
                        pl_module.state_dict(),
                        f"{args.proj_dir}/rwkv-final.pth",
                    )
                    exit(0)
        if trainer.global_step < w_step:
            lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)

        if args.weight_decay_final > 0:
            wd_now = args.weight_decay * math.exp(
                math.log(args.weight_decay_final / args.weight_decay) * progress
            )
        else:
            wd_now = args.weight_decay

        for param_group in trainer.optimizers[0].param_groups:
            if param_group["weight_decay"] > 0:
                param_group["weight_decay"] = wd_now
            if args.layerwise_lr > 0:
                param_group["lr"] = lr * param_group["my_lr_scale"]
                # print(param_group["lr"], param_group["my_lr_scale"])
            else:
                param_group["lr"] = lr

        trainer.my_lr = lr
        trainer.my_wd = wd_now
        # rank_zero_info(f"{real_step} {lr}")

        if trainer.global_step == 0:
            if trainer.is_global_zero:  # logging
                trainer.my_loss_sum = 0
                trainer.my_loss_count = 0
                trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
                trainer.my_log.write(
                    f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n"
                )
                try:
                    print(f"\n{trainer.strategy.config}\n")
                    trainer.my_log.write(f"{trainer.strategy.config}\n")
                except:
                    pass
                trainer.my_log.flush()
                if len(args.wandb) > 0:
                    print("Login to wandb...")
                    import wandb

                    wandb.init(
                        project=args.wandb,
                        name=args.run_name + " " + args.my_timestamp,
                        config=args,
                        save_code=False,
                    )
                    trainer.my_wandb = wandb

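    # --- Illustrative sketch (editor addition, not part of the vendored file) ---
    # The "exp decay" branch above interpolates log-linearly between lr_init and
    # lr_final: progress 0 gives lr_init, progress 1 gives lr_final, and progress 0.5
    # gives their geometric mean.
    @staticmethod
    def _exp_decay_lr_sketch(progress, lr_init, lr_final):
        return lr_init * math.exp(math.log(lr_final / lr_init) * progress)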
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        args = self.args
        token_per_step = args.ctx_len * args.real_bsz
        real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
        if trainer.is_global_zero:  # logging
            t_now = time.time_ns()
            kt_s = 0
            try:
                t_cost = (t_now - trainer.my_time_ns) / 1e9
                kt_s = token_per_step / t_cost / 1000
                self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
                self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
            except:
                pass
            trainer.my_time_ns = t_now
            if pl.__version__[0] == "2":
                trainer.my_loss = outputs["loss"]
            else:
                trainer.my_loss = trainer.my_loss_all.float().mean().item()
            trainer.my_loss_sum += trainer.my_loss
            trainer.my_loss_count += 1
            trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
            self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
            self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
            # self.log("s", real_step, prog_bar=True, on_step=True)

            if len(args.wandb) > 0:
                lll = {
                    "loss": trainer.my_loss,
                    "lr": trainer.my_lr,
                    "wd": trainer.my_wd,
                    "Gtokens": real_step * token_per_step / 1e9,
                }
                if kt_s > 0:
                    lll["kt/s"] = kt_s
                trainer.my_wandb.log(lll, step=int(real_step))
        if (trainer.is_global_zero) or (
            "deepspeed_stage_3" in args.strategy
        ):  # save pth
            if args.magic_prime > 0:
                expand_factor = 2 if args.my_qa_mask > 0 else 1
                if int(real_step) == int(
                    args.magic_prime * expand_factor // args.real_bsz
                ) - 1 + int(args.my_random_steps):
                    to_save_dict = pl_module.state_dict()
                    my_save(
                        args,
                        trainer,
                        to_save_dict,
                        f"{args.proj_dir}/rwkv-final.pth",
                    )
        # if args.batch_save==batch_idx :
        #     to_save_dict = pl_module.state_dict()
        #     for name, state in to_save_dict.items():
        #         if 'img' in name:
        #             to_save_dict[name] = state
        #     try:
        #             my_save(
        #                 args, trainer,
        #                 to_save_dict,
        #                 f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}-{batch_idx}.pth",
        #             )
        #     except Exception as e:
        #         print('Error\n\n', e, '\n\n')

    def on_train_epoch_start(self, trainer, pl_module):
        args = self.args
        if pl.__version__[0] == "2":
            dataset = trainer.train_dataloader.dataset
        else:
            dataset = trainer.train_dataloader.dataset.datasets
        assert "MyDataset" in str(dataset)
        dataset.global_rank = trainer.global_rank
        dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
        dataset.world_size = trainer.world_size
        # print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')

    def on_train_epoch_end(self, trainer, pl_module):
        args = self.args
        to_save_dict = {}
        if (trainer.is_global_zero) or (
            "deepspeed_stage_3" in args.strategy
        ):  # save pth
            if (
                args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0
            ) or (trainer.current_epoch == args.epoch_count - 1):
                if args.data_type == "wds_img":
                    raw_dict = pl_module.state_dict()
                    for k in raw_dict:
                        if k.startswith("encoder.") or k.startswith("decoder."):
                            to_save_dict[k] = raw_dict[k]
                else:
                    to_save_dict = pl_module.state_dict()

                if args.data_type == "img" and not args.lora:
                    for name, state in to_save_dict.items():
                        if "img" in name:
                            to_save_dict[name] = state

                if args.lora:
                    enable_time_finetune = "time" in LORA_CONFIG["parts"]
                    enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
                    lora_dict = {}
                    for name, state in to_save_dict.items():
                        if "img" in name:
                            lora_dict[name] = state
                        if (
                            ".lora_" in name
                            or (enable_time_finetune and ".time_" in name)
                            or (enable_ln_finetune and ".ln" in name)
                        ):
                            lora_dict[name] = state
                    to_save_dict = lora_dict

                try:
                    my_save(
                        args,
                        trainer,
                        to_save_dict,
                        f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
                    )
                except Exception as e:
                    print("Error\n\n", e, "\n\n")

        if trainer.is_global_zero:  # logging
            trainer.my_log.write(
                f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n"
            )
            trainer.my_log.flush()

            trainer.my_loss_sum = 0
            trainer.my_loss_count = 0
            if (args.epoch_begin + trainer.current_epoch) >= args.my_exit:
                exit(0)

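    # --- Illustrative sketch (editor addition, not part of the vendored file) ---
    # When --lora is set, on_train_epoch_end above keeps only the LoRA matrices plus
    # whichever extra parts were enabled via --lora_parts (time-mixing and LayerNorm
    # weights), so the saved checkpoint stays small. Roughly equivalent filter:
    @staticmethod
    def _lora_checkpoint_filter_sketch(state_dict, parts=("time", "ln")):
        keep = {}
        for name, tensor in state_dict.items():
            if (
                ".lora_" in name
                or ("time" in parts and ".time_" in name)
                or ("ln" in parts and ".ln" in name)
            ):
                keep[name] = tensor
        return keep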
@rank_zero_only
def generate_init_weight(model, init_weight_name):
    mm = model.generate_init_weight()

    if model.args.my_pile_stage == 1:
        if len(model.args.load_model) > 0:
            print(f"Combine weights from {model.args.load_model}...")
            load_dict = torch.load(model.args.load_model, map_location="cpu")
            for k in load_dict:
                try:
                    assert k in mm
                except:
                    print("missing", k)
                    exit(0)
                src = load_dict[k]
                try:
                    mm[k] = src.reshape(mm[k].shape)
                except:
                    tmp = mm[k].squeeze().clone()
                    print(k, src.shape, "-->", mm[k].shape)
                    ss = src.shape[0]
                    dd = tmp.shape[0]
                    for i in range(dd):
                        pos = i / dd * ss
                        if pos >= ss - 1:
                            tmp[i] = src[ss - 1]
                        else:
                            p0 = int(math.floor(pos))
                            ii = pos - p0
                            tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii)
                    mm[k] = tmp.reshape(mm[k].shape)
                    sss = src.squeeze().float().cpu().numpy()
                    print(sss[:10], "...", sss[-10:])
                    mmm = mm[k].squeeze().float().cpu().numpy()
                    print(mmm[:10], "...", mmm[-10:])

    print(f"Save to {init_weight_name}...")
    torch.save(mm, init_weight_name)

    if model.args.my_pile_stage == 1:
        print("Done. Now go for stage 2.")
        exit(0)
							
								
								
									
139 finetune/lora/v5/src/utils.py vendored Normal file
@ -0,0 +1,139 @@
import json, time, random, os
import numpy as np
import torch
from torch.nn import functional as F

time_slot = {}
time_ref = time.time_ns()


def record_time(name):
    if name not in time_slot:
        time_slot[name] = 1e20
    tt = (time.time_ns() - time_ref) / 1e9
    if tt < time_slot[name]:
        time_slot[name] = tt


class TOKENIZER:
    def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"):
        if "list" in str(type(WORD_NAME)):
            self.charMode = False
            if WORD_NAME[0] == WORD_NAME[1]:
                from transformers import PreTrainedTokenizerFast

                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
            else:
                from transformers import GPT2TokenizerFast

                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
            self.vocab_size = len(self.tokenizer)
        else:
            self.charMode = True
            with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)

            self.vocab_size = len(self.word_table)

            self.stoi = {v: int(k) for k, v in self.word_table.items()}
            self.itos = {int(k): v for k, v in self.word_table.items()}

            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

					    def refine_context(self, context):
 | 
				
			||||||
 | 
					        context = context.strip().split("\n")
 | 
				
			||||||
 | 
					        for c in range(len(context)):
 | 
				
			||||||
 | 
					            context[c] = context[c].strip().strip("\u3000").strip("\r")
 | 
				
			||||||
 | 
					        context = list(filter(lambda c: c != "", context))
 | 
				
			||||||
 | 
					        context = "\n" + ("\n".join(context)).strip()
 | 
				
			||||||
 | 
					        if context == "":
 | 
				
			||||||
 | 
					            context = "\n"
 | 
				
			||||||
 | 
					        return context
 | 
				
			||||||
 | 
					
    def sample_logits(
        self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None
    ):
        # out[self.UNKNOWN_CHAR] = -float('Inf')
        lastChar = int(x[-1])

        probs = F.softmax(out, dim=-1)

        if self.charMode:
            if self.itos[lastChar] == "\n":
                top_p = top_p_newline
            else:
                top_p = top_p_usual
        else:
            top_p = top_p_usual

        if os.environ["RWKV_RUN_DEVICE"] == "cpu":
            probs = probs.numpy()
            sorted_probs = np.sort(probs)[::-1]
            cumulative_probs = np.cumsum(sorted_probs)
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if temperature != 1.0:
                # probs is a numpy array in this branch; ndarray has no .pow(), so use np.power
                probs = np.power(probs, 1.0 / temperature)
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return out
        else:
            sorted_probs = torch.sort(probs, descending=True)[0]
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if temperature != 1.0:
                probs = probs.pow(1.0 / temperature)
            out = torch.multinomial(probs, num_samples=1)[0]
            return out


def MaybeIsPrime(number):
    if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
        return True
    else:
        return False


def FermatPrimalityTest(number):
    if number > 1:
        for time in range(3):
            randomNumber = random.randint(2, number) - 1
            if pow(randomNumber, number - 1, number) != 1:
                return False
        return True
    else:
        return False


def MillerRabinPrimalityTest(number):
    if number == 2:
        return True
    elif number == 1 or number % 2 == 0:
        return False
    oddPartOfNumber = number - 1
    timesTwoDividNumber = 0
    while oddPartOfNumber % 2 == 0:
        oddPartOfNumber = oddPartOfNumber // 2
        timesTwoDividNumber = timesTwoDividNumber + 1

    for time in range(3):
        while True:
            randomNumber = random.randint(2, number) - 1
            if randomNumber != 0 and randomNumber != 1:
                break

        randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)

        if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
            iterationNumber = 1

            while (iterationNumber <= timesTwoDividNumber - 1) and (
                randomNumberWithPower != number - 1
            ):
                randomNumberWithPower = pow(randomNumberWithPower, 2, number)
                iterationNumber = iterationNumber + 1
            if randomNumberWithPower != (number - 1):
                return False

    return True
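
The CPU branch of TOKENIZER.sample_logits above is plain temperature plus top-p (nucleus) sampling. The standalone sketch below reproduces the same cutoff logic on a made-up logits vector; the function name top_p_sample and the numbers are illustrative only, not part of the repository.

import numpy as np

def top_p_sample(logits, top_p=0.85, temperature=1.0):
    # softmax over the raw logits
    e = np.exp(logits - np.max(logits))
    probs = e / e.sum()
    # keep the smallest set of tokens whose cumulative probability exceeds top_p
    sorted_probs = np.sort(probs)[::-1]
    cumulative = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        probs = np.power(probs, 1.0 / temperature)
    probs = probs / probs.sum()
    return int(np.random.choice(len(probs), p=probs))

print(top_p_sample(np.array([2.0, 1.0, 0.5, -1.0]), top_p=0.7))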

finetune/lora/v5/train.py (436 lines, vendored, Normal file)
@ -0,0 +1,436 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import logging

logging.basicConfig(level=logging.INFO)

if __name__ == "__main__":
    from argparse import ArgumentParser
    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
    import pytorch_lightning as pl

    rank_zero_info("########## work in progress ##########")

    parser = ArgumentParser()

    parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
    parser.add_argument(
        "--wandb", default="", type=str
    )  # wandb project name. if "" then don't use wandb
    parser.add_argument("--proj_dir", default="out", type=str)
    parser.add_argument("--random_seed", default="-1", type=int)

    parser.add_argument("--data_file", default="", type=str)
    parser.add_argument("--data_type", default="utf-8", type=str)
    parser.add_argument(
        "--vocab_size", default=0, type=int
    )  # vocab_size = 0 means auto (for char-level LM and .txt data)

    parser.add_argument("--ctx_len", default=1024, type=int)
    parser.add_argument(
        "--epoch_steps", default=1000, type=int
    )  # a mini "epoch" has [epoch_steps] steps
    parser.add_argument(
        "--epoch_count", default=500, type=int
    )  # train for this many "epochs". will continue afterwards with lr = lr_final
    parser.add_argument(
        "--epoch_begin", default=0, type=int
    )  # if you load a model trained for x "epochs", set epoch_begin = x
    parser.add_argument(
        "--epoch_save", default=5, type=int
    )  # save the model every [epoch_save] "epochs"

    parser.add_argument(
        "--micro_bsz", default=12, type=int
    )  # micro batch size (batch size per GPU)
    parser.add_argument("--n_layer", default=6, type=int)
    parser.add_argument("--n_embd", default=512, type=int)
    parser.add_argument("--dim_att", default=0, type=int)
    parser.add_argument("--dim_ffn", default=0, type=int)
    parser.add_argument(
        "--pre_ffn", default=0, type=int
    )  # replace first att layer by ffn (sometimes better)
    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
    parser.add_argument(
        "--tiny_att_layer", default=-999, type=int
    )  # tiny attention @ which layer

    parser.add_argument(
        "--lr_init", default=6e-4, type=float
    )  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
    parser.add_argument("--lr_final", default=1e-5, type=float)
    parser.add_argument(
        "--warmup_steps", default=-1, type=int
    )  # try 50 if you load a model
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument(
        "--beta2", default=0.99, type=float
    )  # use 0.999 when your model is close to convergence
    parser.add_argument("--adam_eps", default=1e-8, type=float)
    parser.add_argument(
        "--grad_cp", default=0, type=int
    )  # gradient checkpt: saves VRAM, but slower
    parser.add_argument(
        "--dropout", default=0, type=float
    )  # try 0.01 / 0.02 / 0.05 / 0.1
    parser.add_argument(
        "--weight_decay", default=0, type=float
    )  # try 0.1 / 0.01 / 0.001
    parser.add_argument("--weight_decay_final", default=-1, type=float)

    parser.add_argument(
        "--my_pile_version", default=1, type=int
    )  # my special pile version
    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
    parser.add_argument(
        "--my_pile_shift", default=-1, type=int
    )  # my special pile mode - text shift
    parser.add_argument("--my_pile_edecay", default=0, type=int)
    parser.add_argument(
        "--layerwise_lr", default=1, type=int
    )  # layerwise lr for faster convergence (but slower it/s)
    parser.add_argument(
        "--ds_bucket_mb", default=200, type=int
    )  # deepspeed bucket size in MB. 200 seems enough
    # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)

    parser.add_argument("--my_sample_len", default=0, type=int)
    parser.add_argument("--my_ffn_shift", default=1, type=int)
    parser.add_argument("--my_att_shift", default=1, type=int)
    parser.add_argument(
        "--head_size_a", default=64, type=int
    )  # can try larger values for larger models
    parser.add_argument("--head_size_divisor", default=8, type=int)
    parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--load_partial", default=0, type=int)
    parser.add_argument("--magic_prime", default=0, type=int)
    parser.add_argument("--my_qa_mask", default=0, type=int)
    parser.add_argument("--my_random_steps", default=0, type=int)
    parser.add_argument("--my_testing", default="", type=str)
    parser.add_argument("--my_exit", default=99999999, type=int)
    parser.add_argument("--my_exit_tokens", default=0, type=int)

    # LORA
    parser.add_argument("--emb", action="store_true")
    parser.add_argument("--lora", action="store_true")
    parser.add_argument("--lora_load", default="", type=str)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument("--lora_alpha", default=32, type=float)
    parser.add_argument("--lora_dropout", default=0.01, type=float)
    parser.add_argument("--lora_parts", default="att,ln,time", type=str)

    if pl.__version__[0] == "2":
        parser.add_argument("--accelerator", default="gpu", type=str)
        parser.add_argument("--strategy", default="auto", type=str)
        parser.add_argument("--devices", default=1, type=int)
        parser.add_argument("--num_nodes", default=1, type=int)
        parser.add_argument("--precision", default="fp16", type=str)
        parser.add_argument("--accumulate_grad_batches", default=1, type=int)
    else:
        parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    ########################################################################################################

    import os, warnings, math, datetime, sys, time
    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    if "deepspeed" in args.strategy:
        import deepspeed
    from pytorch_lightning import seed_everything

    if args.random_seed >= 0:
        print(
            f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n"
            * 3
        )
        seed_everything(args.random_seed)

    np.set_printoptions(precision=4, suppress=True, linewidth=200)
    warnings.filterwarnings(
        "ignore", ".*Consider increasing the value of the `num_workers` argument*"
    )
    warnings.filterwarnings(
        "ignore", ".*The progress bar already tracks a metric with the*"
    )
    # os.environ["WDS_SHOW_SEED"] = "1"

    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    args.enable_checkpointing = False
    args.replace_sampler_ddp = False
    args.logger = False
    args.gradient_clip_val = 1.0
    args.num_sanity_val_steps = 0
    args.check_val_every_n_epoch = int(1e20)
    args.log_every_n_steps = int(1e20)
    args.max_epochs = args.epoch_count  # -1 continue forever
    args.betas = (args.beta1, args.beta2)
    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
    os.environ["RWKV_MY_TESTING"] = args.my_testing
    os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)
    if args.dim_att <= 0:
        args.dim_att = args.n_embd
    if args.dim_ffn <= 0:
        args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)  # default = 3.5x emb size
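        # e.g. with the default n_embd=512 this gives int(512 * 3.5) // 32 * 32 = 1792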

    if args.data_type == "wds_img":
        args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
        args.proj_dir = f"{args.proj_dir}-{args.run_name}"
    else:
        args.run_name = (
            f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
        )
    if not os.path.exists(args.proj_dir):
        os.makedirs(args.proj_dir)

    if args.my_pile_stage > 0:
        magic_prime_bak = args.magic_prime

        if args.my_pile_shift < 0:
            args.my_pile_shift = 0

        if magic_prime_bak > 0:
            args.magic_prime = magic_prime_bak
        if args.my_qa_mask == 2:
            args.epoch_count = 2 * args.magic_prime // 40320
        else:
            args.epoch_count = args.magic_prime // 40320

        args.epoch_steps = 40320 // args.real_bsz
        assert args.epoch_steps * args.real_bsz == 40320
        # if args.my_pile_stage == 2:
        #     assert args.lr_final == args.lr_init
        if args.my_pile_stage >= 2:  # find latest saved model
            list_p = []
            for p in os.listdir(args.proj_dir):
                if p.startswith("rwkv") and p.endswith(".pth"):
                    p = ((p.split("-"))[1].split("."))[0]
                    if p != "final":
                        if p == "init":
                            p = -1
                        else:
                            p = int(p)
                        list_p += [p]
            list_p.sort()
            max_p = list_p[-1]
            if len(list_p) > 1:
                args.my_pile_prev_p = list_p[-2]  # in case max_p is corrupted
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
                if args.warmup_steps < 0:
                    if args.my_pile_stage == 2:
                        args.warmup_steps = 10
                    else:
                        args.warmup_steps = 30
            args.epoch_begin = max_p + 1

    samples_per_epoch = args.epoch_steps * args.real_bsz
    tokens_per_epoch = samples_per_epoch * args.ctx_len
    try:
        deepspeed_version = deepspeed.__version__
    except:
        deepspeed_version = None
        pass
    rank_zero_info(
        f"""
############################################################################
#
# RWKV-5 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1}, save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed_version}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.9.5
#
############################################################################
"""
    )
    rank_zero_info(str(vars(args)) + "\n")

    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "uint16"]

    if args.lr_final == 0 or args.lr_init == 0:
        rank_zero_info(
            "\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n"
        )

    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
    os.environ["RWKV_FLOAT_MODE"] = args.precision
    if args.precision == "fp32":
        for i in range(10):
            rank_zero_info(
                "\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n"
            )
    if args.precision == "fp16":
        rank_zero_info(
            "\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n"
        )

    os.environ["RWKV_JIT_ON"] = "0"
    if "deepspeed_stage_3" in args.strategy:
        os.environ["RWKV_JIT_ON"] = "0"

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    if args.precision == "fp32":
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
    else:
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True

    if "32" in args.precision:
        args.precision = 32
    elif args.precision == "fp16":
        args.precision = 16
    else:
        args.precision = "bf16"

    ########################################################################################################

    from src.trainer import train_callback, generate_init_weight
    from src.dataset import MyDataset

    train_data = MyDataset(args)
    args.vocab_size = train_data.vocab_size

    from src.model import RWKV, LORA_CONFIG, LoraLinear

    if args.lora:
        assert args.lora_r > 0, "LoRA should have its `r` > 0"
        LORA_CONFIG["r"] = args.lora_r
        LORA_CONFIG["alpha"] = args.lora_alpha
        LORA_CONFIG["dropout"] = args.lora_dropout
        LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
        enable_time_finetune = "time" in LORA_CONFIG["parts"]
        enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
    model = RWKV(args)
    # only train lora parameters
    if args.lora:
        model.requires_grad_(False)
        for name, module in model.named_modules():
            if any(n.startswith("lora_") for n, _ in module.named_parameters()):
                print(f"  LoRA additionally training module {name}")
                for pname, param in module.named_parameters():
                    param.requires_grad = "lora_" in pname
            elif enable_ln_finetune and ".ln" in name:
                print(f"  LoRA additionally training module {name}")
                for param in module.parameters():
                    param.requires_grad = True
            elif enable_time_finetune and any(
                n.startswith("time") for n, _ in module.named_parameters()
            ):
                for pname, param in module.named_parameters():
                    if pname.startswith("time"):
                        print(f"  LoRA additionally training parameter {pname}")
                        param.requires_grad = True
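    # In short: with --lora the whole model is frozen first, then three groups are
    # re-enabled: parameters whose names start with "lora_", LayerNorm modules
    # (only when "ln" is in --lora_parts), and "time"-prefixed mixing parameters
    # (only when "time" is in --lora_parts).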

    if (
        len(args.load_model) == 0 or args.my_pile_stage == 1
    ):  # shall we build the initial weights?
        init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
        generate_init_weight(model, init_weight_name)  # save initial weights
        args.load_model = init_weight_name

    rank_zero_info(f"########## Loading {args.load_model}... ##########")
    try:
        load_dict = torch.load(args.load_model, map_location="cpu")
        load_keys = list(load_dict.keys())
        for k in load_keys:
            if k.startswith("_forward_module."):
                load_dict[k.replace("_forward_module.", "")] = load_dict[k]
                del load_dict[k]
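                # (checkpoints written through the Lightning/DeepSpeed wrapper
                # typically prefix keys with "_forward_module."; stripping it here
                # lets the bare RWKV module load them)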
    except:
        rank_zero_info(f"Bad checkpoint {args.load_model}")
        if args.my_pile_stage >= 2:  # try again using another checkpoint
            max_p = args.my_pile_prev_p
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
            args.epoch_begin = max_p + 1
            rank_zero_info(f"Trying {args.load_model}")
            load_dict = torch.load(args.load_model, map_location="cpu")

    if args.load_partial == 1:
        load_keys = load_dict.keys()
        for k in model.state_dict():
            if k not in load_keys:
                load_dict[k] = model.state_dict()[k]
    # model.load_state_dict(load_dict)

    model.load_state_dict(load_dict, strict=(not args.lora))
    if os.path.isfile(args.lora_load):
        model.load_state_dict(
            torch.load(args.lora_load, map_location="cpu"), strict=False
        )

    if pl.__version__[0] == "2":
        trainer = Trainer(
            accelerator=args.accelerator,
            strategy=args.strategy,
            devices=args.devices,
            num_nodes=args.num_nodes,
            precision=args.precision,
            logger=args.logger,
            callbacks=[train_callback(args)],
            max_epochs=args.max_epochs,
            check_val_every_n_epoch=args.check_val_every_n_epoch,
            num_sanity_val_steps=args.num_sanity_val_steps,
            log_every_n_steps=args.log_every_n_steps,
            enable_checkpointing=args.enable_checkpointing,
            accumulate_grad_batches=args.accumulate_grad_batches,
            gradient_clip_val=args.gradient_clip_val,
        )
    else:
        trainer = Trainer.from_argparse_args(
            args,
            callbacks=[train_callback(args)],
        )

    if trainer.global_rank == 0:
        for n in model.state_dict():
            shape = model.state_dict()[n].shape
            shape = [i for i in shape if i != 1]
            if len(shape) > 1:
                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
            else:
                print(f"{str(shape[0]).ljust(5)}       {n}")

    if "deepspeed" in args.strategy:
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
            args.ds_bucket_mb * 1000 * 1000
        )
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
            args.ds_bucket_mb * 1000 * 1000
        )

    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
    data_loader = DataLoader(
        train_data,
        shuffle=False,
        pin_memory=True,
        batch_size=args.micro_bsz,
        num_workers=1,
        persistent_workers=False,
        drop_last=True,
    )

    trainer.fit(model, data_loader)
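
The src/model.py that defines RWKV, LORA_CONFIG and LoraLinear is not part of this diff, so as a reading aid here is a generic sketch of the low-rank adapter a LoraLinear-style layer usually applies. The class name SketchLoraLinear and all details below are illustrative, following the standard LoRA formulation (frozen base weight plus a scaled B(Ax) update); the "lora_" parameter prefix is what the freezing loop in train.py keys on.

# illustrative only: the repository's LoraLinear in src/model.py may differ
import math
import torch
import torch.nn as nn

class SketchLoraLinear(nn.Module):
    def __init__(self, in_features, out_features, r=8, alpha=32, dropout=0.01):
        super().__init__()
        # frozen base weight
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.weight.requires_grad_(False)
        # low-rank factors; lora_B starts at zero so the adapter is a no-op at init
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # frozen base projection plus the scaled low-rank update B(Ax)
        base = nn.functional.linear(x, self.weight)
        lora = nn.functional.linear(
            nn.functional.linear(self.dropout(x), self.lora_A), self.lora_B
        )
        return base + self.scaling * lora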
@ -131,7 +131,7 @@ const showError = (e: any) => {
};

const errorsMap = Object.entries({
-  'python3 ./finetune/lora/train.py': 'Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.',
+  'python3 ./finetune/lora/v': 'Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.',
  'cuda out of memory': 'VRAM is not enough',
  'valueerror: high <= 0': 'Training data is not enough, reduce context length or add more data for training',
  '+= \'+ptx\'': 'Can not find an Nvidia GPU. Perhaps the gpu driver of windows is too old, or you are using WSL 1 for training, please upgrade to WSL 2. e.g. Run "wsl --set-version Ubuntu-22.04 2"',
@ -299,7 +299,6 @@ const LoraFinetune: FC = observer(() => {
          (loraParams.baseModel ? `--load_model models/${loraParams.baseModel} ` : '') +
          (loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') +
          `--data_file ${convertedDataPath} ` +
-          `--vocab_size ${loraParams.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` +
          `--ctx_len ${ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
          `--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` +
          `--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` +