rwkv5 lora finetune support (https://github.com/JL-er/RWKV-v5-lora)
This commit is contained in:
parent
b7f4dd835e
commit
81544ca8b3
@ -32,6 +32,7 @@ cleaner_thread.start()
|
|||||||
w = torch.load(model_file, map_location="cpu")
|
w = torch.load(model_file, map_location="cpu")
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
vocab_size = w["emb.weight"].shape[0]
|
||||||
n_embd = w["emb.weight"].shape[1]
|
n_embd = w["emb.weight"].shape[1]
|
||||||
n_layer = 0
|
n_layer = 0
|
||||||
keys = list(w.keys())
|
keys = list(w.keys())
|
||||||
@ -52,6 +53,9 @@ for x in keys:
|
|||||||
version = max(6, version)
|
version = max(6, version)
|
||||||
|
|
||||||
if version <= expected_max_version:
|
if version <= expected_max_version:
|
||||||
print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
|
print(
|
||||||
|
f"v{int(version)}/train.py --vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}",
|
||||||
|
end="",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"RWKV{version} is not supported")
|
raise Exception(f"RWKV{version} is not supported")
|
||||||
|
@ -47,10 +47,10 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
echo "loading $loadModel"
|
echo "loading $loadModel"
|
||||||
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 4)
|
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 5.2)
|
||||||
echo $modelInfo
|
echo $modelInfo
|
||||||
if [[ $modelInfo =~ "--n_layer" ]]; then
|
if [[ $modelInfo =~ "--n_layer" ]]; then
|
||||||
python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
|
python3 ./finetune/lora/$modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
|
||||||
--lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu
|
--lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu
|
||||||
else
|
else
|
||||||
echo "modelInfo is invalid"
|
echo "modelInfo is invalid"
|
||||||
|
@ -7,6 +7,7 @@ import struct
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from itertools import accumulate
|
from itertools import accumulate
|
||||||
|
|
||||||
|
|
||||||
def print_rank_0(*message):
|
def print_rank_0(*message):
|
||||||
pass
|
pass
|
||||||
# """If distributed is initialized print only on rank 0."""
|
# """If distributed is initialized print only on rank 0."""
|
||||||
@ -16,12 +17,14 @@ def print_rank_0(*message):
|
|||||||
# else:
|
# else:
|
||||||
# print(*message, flush=True)
|
# print(*message, flush=True)
|
||||||
|
|
||||||
|
|
||||||
def _warmup_mmap_file(path):
|
def _warmup_mmap_file(path):
|
||||||
pass
|
pass
|
||||||
# with open(path, "rb") as stream:
|
# with open(path, "rb") as stream:
|
||||||
# while stream.read(100 * 1024 * 1024):
|
# while stream.read(100 * 1024 * 1024):
|
||||||
# pass
|
# pass
|
||||||
|
|
||||||
|
|
||||||
dtypes = {
|
dtypes = {
|
||||||
1: np.uint8,
|
1: np.uint8,
|
||||||
2: np.int8,
|
2: np.int8,
|
||||||
@ -33,18 +36,22 @@ dtypes = {
|
|||||||
8: np.uint16,
|
8: np.uint16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def code(dtype):
|
def code(dtype):
|
||||||
for k in dtypes.keys():
|
for k in dtypes.keys():
|
||||||
if dtypes[k] == dtype:
|
if dtypes[k] == dtype:
|
||||||
return k
|
return k
|
||||||
raise ValueError(dtype)
|
raise ValueError(dtype)
|
||||||
|
|
||||||
|
|
||||||
def index_file_path(prefix_path):
|
def index_file_path(prefix_path):
|
||||||
return prefix_path + ".idx"
|
return prefix_path + ".idx"
|
||||||
|
|
||||||
|
|
||||||
def data_file_path(prefix_path):
|
def data_file_path(prefix_path):
|
||||||
return prefix_path + ".bin"
|
return prefix_path + ".bin"
|
||||||
|
|
||||||
|
|
||||||
class MMapIndexedDataset(torch.utils.data.Dataset):
|
class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||||
class Index(object):
|
class Index(object):
|
||||||
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||||
@ -217,8 +224,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
|
|||||||
elif isinstance(idx, slice):
|
elif isinstance(idx, slice):
|
||||||
start, stop, step = idx.indices(len(self))
|
start, stop, step = idx.indices(len(self))
|
||||||
if step != 1:
|
if step != 1:
|
||||||
raise ValueError(
|
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||||
"Slices into indexed_dataset must be contiguous")
|
|
||||||
ptr = self._index._pointers[start]
|
ptr = self._index._pointers[start]
|
||||||
sizes = self._index._sizes[idx]
|
sizes = self._index._sizes[idx]
|
||||||
offsets = list(accumulate(sizes))
|
offsets = list(accumulate(sizes))
|
@ -17,9 +17,11 @@ class MyDataset(Dataset):
|
|||||||
|
|
||||||
if args.data_type == "binidx":
|
if args.data_type == "binidx":
|
||||||
self.vocab_size = args.vocab_size
|
self.vocab_size = args.vocab_size
|
||||||
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
|
rank_zero_info(
|
||||||
|
f"Current vocab size = {self.vocab_size} (make sure it's correct)"
|
||||||
|
)
|
||||||
|
|
||||||
if args.data_file.endswith('/'):
|
if args.data_file.endswith("/"):
|
||||||
d_all = []
|
d_all = []
|
||||||
for p in os.listdir(args.data_file):
|
for p in os.listdir(args.data_file):
|
||||||
if p.endswith(".idx"):
|
if p.endswith(".idx"):
|
||||||
@ -29,33 +31,52 @@ class MyDataset(Dataset):
|
|||||||
exit(0)
|
exit(0)
|
||||||
else:
|
else:
|
||||||
self.data = MMapIndexedDataset(args.data_file)
|
self.data = MMapIndexedDataset(args.data_file)
|
||||||
self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
|
self.data_size = (
|
||||||
|
len(self.data._bin_buffer) // self.data._index._dtype_size
|
||||||
|
)
|
||||||
rank_zero_info(f"Data has {self.data_size} tokens.")
|
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||||
|
|
||||||
if args.my_qa_mask > 0:
|
if args.my_qa_mask > 0:
|
||||||
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
|
self.data_pile = MMapIndexedDataset(
|
||||||
self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size
|
"/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document"
|
||||||
|
)
|
||||||
|
self.data_pile_size = (
|
||||||
|
len(self.data_pile._bin_buffer) // self.data._index._dtype_size
|
||||||
|
)
|
||||||
|
|
||||||
if args.my_pile_stage > 0:
|
if args.my_pile_stage > 0:
|
||||||
# assert self.data_size == 332115325534 and self.vocab_size == 50277
|
# assert self.data_size == 332115325534 and self.vocab_size == 50277
|
||||||
self.samples_per_epoch = args.epoch_steps * args.real_bsz
|
self.samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||||
assert self.samples_per_epoch == 40320
|
assert self.samples_per_epoch == 40320
|
||||||
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
|
rank_zero_info(
|
||||||
|
f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########"
|
||||||
|
)
|
||||||
dataset_slot = self.data_size // args.ctx_len
|
dataset_slot = self.data_size // args.ctx_len
|
||||||
if args.my_pile_stage != 4:
|
if args.my_pile_stage != 4:
|
||||||
assert MaybeIsPrime(args.magic_prime)
|
assert MaybeIsPrime(args.magic_prime)
|
||||||
assert args.magic_prime % 3 == 2
|
assert args.magic_prime % 3 == 2
|
||||||
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
|
assert (
|
||||||
|
args.magic_prime / dataset_slot > 0.99
|
||||||
|
and args.magic_prime / dataset_slot <= 1
|
||||||
|
)
|
||||||
elif args.data_type == "numpy":
|
elif args.data_type == "numpy":
|
||||||
self.data = np.load(args.data_file).astype("int")
|
self.data = np.load(args.data_file).astype("int")
|
||||||
self.vocab_size = args.vocab_size
|
self.vocab_size = args.vocab_size
|
||||||
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
rank_zero_info(
|
||||||
|
"Current vocab size =", self.vocab_size, "(make sure it's correct)"
|
||||||
|
)
|
||||||
self.data_size = len(self.data)
|
self.data_size = len(self.data)
|
||||||
rank_zero_info(f"Data has {self.data_size} tokens.")
|
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||||
elif args.data_type == "uint16":
|
elif args.data_type == "uint16":
|
||||||
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
|
self.data = (
|
||||||
|
np.fromfile(args.data_file, dtype=np.uint16)
|
||||||
|
.astype("int32")
|
||||||
|
.reshape(-1, args.my_sample_len)
|
||||||
|
)
|
||||||
self.vocab_size = args.vocab_size
|
self.vocab_size = args.vocab_size
|
||||||
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
rank_zero_info(
|
||||||
|
"Current vocab size =", self.vocab_size, "(make sure it's correct)"
|
||||||
|
)
|
||||||
self.data_size = self.data.shape[0]
|
self.data_size = self.data.shape[0]
|
||||||
rank_zero_info(f"Data has {self.data_size} samples.")
|
rank_zero_info(f"Data has {self.data_size} samples.")
|
||||||
elif args.data_type == "wds_img":
|
elif args.data_type == "wds_img":
|
||||||
@ -86,10 +107,14 @@ class MyDataset(Dataset):
|
|||||||
for u in unique:
|
for u in unique:
|
||||||
xxObj[xx] = u
|
xxObj[xx] = u
|
||||||
xx += 1
|
xx += 1
|
||||||
with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
|
with open(
|
||||||
|
f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le"
|
||||||
|
) as vocab_file:
|
||||||
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
|
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
|
||||||
self.data_size = len(self.data)
|
self.data_size = len(self.data)
|
||||||
rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
|
rank_zero_info(
|
||||||
|
f"Data has {self.data_size} tokens, {self.vocab_size} vocab size."
|
||||||
|
)
|
||||||
self.stoi = {ch: i for i, ch in enumerate(unique)}
|
self.stoi = {ch: i for i, ch in enumerate(unique)}
|
||||||
self.itos = {i: ch for i, ch in enumerate(unique)}
|
self.itos = {i: ch for i, ch in enumerate(unique)}
|
||||||
|
|
||||||
@ -104,27 +129,42 @@ class MyDataset(Dataset):
|
|||||||
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
|
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
|
||||||
|
|
||||||
if args.data_type == "wds_img":
|
if args.data_type == "wds_img":
|
||||||
|
|
||||||
def init_wds(self, bias=0):
|
def init_wds(self, bias=0):
|
||||||
def identity(x):
|
def identity(x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
import webdataset as wds
|
import webdataset as wds
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
# img_transform = transforms.Compose(
|
# img_transform = transforms.Compose(
|
||||||
# [transforms.CenterCrop(256)]
|
# [transforms.CenterCrop(256)]
|
||||||
# )
|
# )
|
||||||
img_transform = transforms.Compose([
|
img_transform = transforms.Compose(
|
||||||
transforms.CenterCrop(512),
|
[transforms.CenterCrop(512), transforms.Resize((args.my_img_size))]
|
||||||
transforms.Resize((args.my_img_size))
|
)
|
||||||
])
|
self.data_raw = (
|
||||||
self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
|
wds.WebDataset(args.data_file, resampled=True)
|
||||||
|
.shuffle(
|
||||||
|
10000,
|
||||||
|
initial=1000,
|
||||||
|
rng=random.Random(epoch * 100000 + rank + bias * 1e9),
|
||||||
|
)
|
||||||
|
.decode("torchrgb")
|
||||||
|
.to_tuple("jpg", "json", "txt")
|
||||||
|
.map_tuple(img_transform, identity, identity)
|
||||||
|
)
|
||||||
for pp in self.data_raw.pipeline:
|
for pp in self.data_raw.pipeline:
|
||||||
if 'Resampled' in str(pp):
|
if "Resampled" in str(pp):
|
||||||
pp.deterministic = True
|
pp.deterministic = True
|
||||||
|
|
||||||
def worker_seed():
|
def worker_seed():
|
||||||
return rank * 100000 + epoch + bias * 1e9
|
return rank * 100000 + epoch + bias * 1e9
|
||||||
|
|
||||||
pp.worker_seed = worker_seed
|
pp.worker_seed = worker_seed
|
||||||
self.data = iter(self.data_raw)
|
self.data = iter(self.data_raw)
|
||||||
# print(f"WebDataset loaded for rank {rank} epoch {epoch}")
|
# print(f"WebDataset loaded for rank {rank} epoch {epoch}")
|
||||||
|
|
||||||
if self.data == None:
|
if self.data == None:
|
||||||
init_wds(self)
|
init_wds(self)
|
||||||
trial = 0
|
trial = 0
|
||||||
@ -133,7 +173,9 @@ class MyDataset(Dataset):
|
|||||||
dd = next(self.data) # jpg, json, txt
|
dd = next(self.data) # jpg, json, txt
|
||||||
break
|
break
|
||||||
except:
|
except:
|
||||||
print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
|
print(
|
||||||
|
f"[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]"
|
||||||
|
)
|
||||||
self.error_count += 1
|
self.error_count += 1
|
||||||
init_wds(self, self.error_count)
|
init_wds(self, self.error_count)
|
||||||
trial += 1
|
trial += 1
|
||||||
@ -196,7 +238,12 @@ class MyDataset(Dataset):
|
|||||||
z_sum = 0
|
z_sum = 0
|
||||||
isGood = False
|
isGood = False
|
||||||
for i in range(3, ctx_len):
|
for i in range(3, ctx_len):
|
||||||
if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
|
if (
|
||||||
|
dix[i] == 27
|
||||||
|
and dix[i - 1] == 34
|
||||||
|
and dix[i - 2] == 187
|
||||||
|
and dix[i - 3] == 187
|
||||||
|
):
|
||||||
isGood = True
|
isGood = True
|
||||||
if dix[i] == 0:
|
if dix[i] == 0:
|
||||||
isGood = False
|
isGood = False
|
||||||
@ -206,7 +253,9 @@ class MyDataset(Dataset):
|
|||||||
if z_sum == 0:
|
if z_sum == 0:
|
||||||
z = [1] * ctx_len
|
z = [1] * ctx_len
|
||||||
i = np.random.randint(0, self.data_pile_size - req_len)
|
i = np.random.randint(0, self.data_pile_size - req_len)
|
||||||
dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
|
dix = self.data_pile.get(
|
||||||
|
idx=0, offset=i, length=req_len
|
||||||
|
).astype(int)
|
||||||
z = torch.tensor(z, dtype=torch.bfloat16)
|
z = torch.tensor(z, dtype=torch.bfloat16)
|
||||||
|
|
||||||
x = torch.tensor(dix[:-1], dtype=torch.long)
|
x = torch.tensor(dix[:-1], dtype=torch.long)
|
@ -5,6 +5,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import os, math, gc, importlib
|
import os, math, gc, importlib
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# torch._C._jit_set_profiling_executor(True)
|
# torch._C._jit_set_profiling_executor(True)
|
||||||
# torch._C._jit_set_profiling_mode(True)
|
# torch._C._jit_set_profiling_mode(True)
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -13,7 +14,8 @@ from torch.nn import functional as F
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||||
from pytorch_lightning.strategies import DeepSpeedStrategy
|
from pytorch_lightning.strategies import DeepSpeedStrategy
|
||||||
if importlib.util.find_spec('deepspeed'):
|
|
||||||
|
if importlib.util.find_spec("deepspeed"):
|
||||||
import deepspeed
|
import deepspeed
|
||||||
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
||||||
|
|
||||||
@ -28,9 +30,10 @@ LORA_CONFIG = {
|
|||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
|
print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"])
|
||||||
except:
|
except:
|
||||||
os.environ["RWKV_MY_TESTING"] = ''
|
os.environ["RWKV_MY_TESTING"] = ""
|
||||||
|
|
||||||
|
|
||||||
def __nop(ob):
|
def __nop(ob):
|
||||||
return ob
|
return ob
|
||||||
@ -53,7 +56,26 @@ T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
|
|||||||
from torch.utils.cpp_extension import load
|
from torch.utils.cpp_extension import load
|
||||||
|
|
||||||
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||||
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["finetune/lora/cuda/wkv_op_bf16.cpp", "finetune/lora/cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
|
wkv_cuda = load(
|
||||||
|
name=f"wkv_{T_MAX}_bf16",
|
||||||
|
sources=[
|
||||||
|
"finetune/lora/v4/cuda/wkv_op_bf16.cpp",
|
||||||
|
"finetune/lora/v4/cuda/wkv_cuda_bf16.cu",
|
||||||
|
],
|
||||||
|
verbose=True,
|
||||||
|
extra_cuda_cflags=[
|
||||||
|
"-t 4",
|
||||||
|
"-std=c++17",
|
||||||
|
"-res-usage",
|
||||||
|
"--maxrregcount 60",
|
||||||
|
"--use_fast_math",
|
||||||
|
"-O3",
|
||||||
|
"-Xptxas -O3",
|
||||||
|
"--extra-device-vectorization",
|
||||||
|
f"-DTmax={T_MAX}",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
class WKV(torch.autograd.Function):
|
class WKV(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, B, T, C, w, u, k, v):
|
def forward(ctx, B, T, C, w, u, k, v):
|
||||||
@ -66,10 +88,16 @@ if os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
|||||||
u = u.contiguous()
|
u = u.contiguous()
|
||||||
k = k.contiguous()
|
k = k.contiguous()
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
y = torch.empty(
|
||||||
|
(B, T, C),
|
||||||
|
device=w.device,
|
||||||
|
memory_format=torch.contiguous_format,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
||||||
ctx.save_for_backward(w, u, k, v, y)
|
ctx.save_for_backward(w, u, k, v, y)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, gy):
|
def backward(ctx, gy):
|
||||||
B = ctx.B
|
B = ctx.B
|
||||||
@ -78,16 +106,54 @@ if os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
|||||||
assert T <= T_MAX
|
assert T <= T_MAX
|
||||||
assert B * C % min(C, 32) == 0
|
assert B * C % min(C, 32) == 0
|
||||||
w, u, k, v, y = ctx.saved_tensors
|
w, u, k, v, y = ctx.saved_tensors
|
||||||
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
gw = torch.empty(
|
||||||
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
(B, C),
|
||||||
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
device=gy.device,
|
||||||
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
memory_format=torch.contiguous_format,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
gu = torch.empty(
|
||||||
|
(B, C),
|
||||||
|
device=gy.device,
|
||||||
|
memory_format=torch.contiguous_format,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
gk = torch.empty(
|
||||||
|
(B, T, C),
|
||||||
|
device=gy.device,
|
||||||
|
memory_format=torch.contiguous_format,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
gv = torch.empty(
|
||||||
|
(B, T, C),
|
||||||
|
device=gy.device,
|
||||||
|
memory_format=torch.contiguous_format,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
|
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
|
||||||
gw = torch.sum(gw, dim=0)
|
gw = torch.sum(gw, dim=0)
|
||||||
gu = torch.sum(gu, dim=0)
|
gu = torch.sum(gu, dim=0)
|
||||||
return (None, None, None, gw, gu, gk, gv)
|
return (None, None, None, gw, gu, gk, gv)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["finetune/lora/cuda/wkv_op.cpp", "finetune/lora/cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
|
wkv_cuda = load(
|
||||||
|
name=f"wkv_{T_MAX}",
|
||||||
|
sources=[
|
||||||
|
"finetune/lora/v4/cuda/wkv_op.cpp",
|
||||||
|
"finetune/lora/v4/cuda/wkv_cuda.cu",
|
||||||
|
],
|
||||||
|
verbose=True,
|
||||||
|
extra_cuda_cflags=[
|
||||||
|
"-res-usage",
|
||||||
|
"--maxrregcount 60",
|
||||||
|
"--use_fast_math",
|
||||||
|
"-O3",
|
||||||
|
"-Xptxas -O3",
|
||||||
|
"--extra-device-vectorization",
|
||||||
|
f"-DTmax={T_MAX}",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
class WKV(torch.autograd.Function):
|
class WKV(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, B, T, C, w, u, k, v):
|
def forward(ctx, B, T, C, w, u, k, v):
|
||||||
@ -106,7 +172,9 @@ else:
|
|||||||
u = u.float().contiguous()
|
u = u.float().contiguous()
|
||||||
k = k.float().contiguous()
|
k = k.float().contiguous()
|
||||||
v = v.float().contiguous()
|
v = v.float().contiguous()
|
||||||
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
|
y = torch.empty(
|
||||||
|
(B, T, C), device=w.device, memory_format=torch.contiguous_format
|
||||||
|
)
|
||||||
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
||||||
ctx.save_for_backward(w, u, k, v, y)
|
ctx.save_for_backward(w, u, k, v, y)
|
||||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||||
@ -115,6 +183,7 @@ else:
|
|||||||
return y.half()
|
return y.half()
|
||||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||||
return y.bfloat16()
|
return y.bfloat16()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, gy):
|
def backward(ctx, gy):
|
||||||
B = ctx.B
|
B = ctx.B
|
||||||
@ -123,14 +192,26 @@ else:
|
|||||||
assert T <= T_MAX
|
assert T <= T_MAX
|
||||||
assert B * C % min(C, 32) == 0
|
assert B * C % min(C, 32) == 0
|
||||||
w, u, k, v, y = ctx.saved_tensors
|
w, u, k, v, y = ctx.saved_tensors
|
||||||
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
|
gw = torch.empty(
|
||||||
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
|
(B, C), device=gy.device, memory_format=torch.contiguous_format
|
||||||
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
|
)
|
||||||
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
|
gu = torch.empty(
|
||||||
|
(B, C), device=gy.device, memory_format=torch.contiguous_format
|
||||||
|
)
|
||||||
|
gk = torch.empty(
|
||||||
|
(B, T, C), device=gy.device, memory_format=torch.contiguous_format
|
||||||
|
)
|
||||||
|
gv = torch.empty(
|
||||||
|
(B, T, C), device=gy.device, memory_format=torch.contiguous_format
|
||||||
|
)
|
||||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
|
wkv_cuda.backward(
|
||||||
|
B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
|
wkv_cuda.backward(
|
||||||
|
B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv
|
||||||
|
)
|
||||||
gw = torch.sum(gw, dim=0)
|
gw = torch.sum(gw, dim=0)
|
||||||
gu = torch.sum(gu, dim=0)
|
gu = torch.sum(gu, dim=0)
|
||||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||||
@ -138,7 +219,15 @@ else:
|
|||||||
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||||
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
|
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
|
||||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||||
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
|
return (
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
gw.bfloat16(),
|
||||||
|
gu.bfloat16(),
|
||||||
|
gk.bfloat16(),
|
||||||
|
gv.bfloat16(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def RUN_CUDA(B, T, C, w, u, k, v):
|
def RUN_CUDA(B, T, C, w, u, k, v):
|
||||||
@ -151,15 +240,17 @@ def RUN_CUDA(B, T, C, w, u, k, v):
|
|||||||
|
|
||||||
|
|
||||||
class LoraLinear(nn.Module):
|
class LoraLinear(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_features: int, out_features: int, bias: bool):
|
def __init__(self, in_features: int, out_features: int, bias: bool):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
|
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
|
||||||
assert bias == False, "Biased LoraLinear not supported"
|
assert bias == False, "Biased LoraLinear not supported"
|
||||||
|
|
||||||
r, alpha, dropout = LORA_CONFIG["r"], LORA_CONFIG[
|
r, alpha, dropout = (
|
||||||
"alpha"], LORA_CONFIG["dropout"]
|
LORA_CONFIG["r"],
|
||||||
|
LORA_CONFIG["alpha"],
|
||||||
|
LORA_CONFIG["dropout"],
|
||||||
|
)
|
||||||
self.lora_A = nn.Parameter(torch.empty(r, in_features))
|
self.lora_A = nn.Parameter(torch.empty(r, in_features))
|
||||||
self.lora_B = nn.Parameter(torch.empty(out_features, r))
|
self.lora_B = nn.Parameter(torch.empty(out_features, r))
|
||||||
self.lora_dropout = nn.Dropout(dropout)
|
self.lora_dropout = nn.Dropout(dropout)
|
||||||
@ -170,9 +261,9 @@ class LoraLinear(nn.Module):
|
|||||||
nn.init.zeros_(self.lora_B)
|
nn.init.zeros_(self.lora_B)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return (
|
return F.linear(x, self.weight) + self.scaling * F.linear(
|
||||||
F.linear(x, self.weight) + self.scaling *
|
F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
|
||||||
F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B))
|
)
|
||||||
|
|
||||||
|
|
||||||
@functools.wraps(LoraLinear)
|
@functools.wraps(LoraLinear)
|
||||||
@ -214,17 +305,23 @@ class RWKV_TimeMix(MyModule):
|
|||||||
# fancy time_decay
|
# fancy time_decay
|
||||||
decay_speed = torch.ones(args.dim_att)
|
decay_speed = torch.ones(args.dim_att)
|
||||||
for h in range(args.dim_att):
|
for h in range(args.dim_att):
|
||||||
decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
|
decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (
|
||||||
|
0.7 + 1.3 * ratio_0_to_1
|
||||||
|
)
|
||||||
self.time_decay = nn.Parameter(decay_speed)
|
self.time_decay = nn.Parameter(decay_speed)
|
||||||
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
|
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
|
||||||
|
|
||||||
# fancy time_first
|
# fancy time_first
|
||||||
zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
|
zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
|
||||||
self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)
|
self.time_first = nn.Parameter(
|
||||||
|
torch.ones(args.dim_att) * math.log(0.3) + zigzag
|
||||||
|
)
|
||||||
|
|
||||||
# fancy time_mix
|
# fancy time_mix
|
||||||
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||||
self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
self.time_mix_v = nn.Parameter(
|
||||||
|
torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
|
||||||
|
)
|
||||||
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
||||||
|
|
||||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||||
@ -235,8 +332,10 @@ class RWKV_TimeMix(MyModule):
|
|||||||
|
|
||||||
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
|
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
|
||||||
|
|
||||||
if 'a' in os.environ["RWKV_MY_TESTING"]:
|
if "a" in os.environ["RWKV_MY_TESTING"]:
|
||||||
self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
self.register_buffer(
|
||||||
|
"att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||||||
|
)
|
||||||
d_qkv = args.n_embd // 16
|
d_qkv = args.n_embd // 16
|
||||||
self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
|
self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
|
||||||
self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
|
self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
|
||||||
@ -245,12 +344,17 @@ class RWKV_TimeMix(MyModule):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||||
self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||||
self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
self.time_mix_vv = nn.Parameter(
|
||||||
|
torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
|
||||||
|
)
|
||||||
|
|
||||||
|
if "a" not in os.environ["RWKV_MY_TESTING"]:
|
||||||
|
|
||||||
if 'a' not in os.environ["RWKV_MY_TESTING"]:
|
|
||||||
@MyFunction
|
@MyFunction
|
||||||
def jit_func(self, x):
|
def jit_func(self, x):
|
||||||
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
|
xx = self.time_shift(
|
||||||
|
x
|
||||||
|
) # Mix x with the previous timestep to produce xk, xv, xr
|
||||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||||
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||||
@ -263,21 +367,26 @@ class RWKV_TimeMix(MyModule):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, T, C = x.size() # x = (Batch,Time,Channel)
|
B, T, C = x.size() # x = (Batch,Time,Channel)
|
||||||
sr, k, v = self.jit_func(x)
|
sr, k, v = self.jit_func(x)
|
||||||
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
|
rwkv = sr * RUN_CUDA(
|
||||||
|
B, T, self.args.dim_att, self.time_decay, self.time_first, k, v
|
||||||
|
)
|
||||||
return self.output(rwkv)
|
return self.output(rwkv)
|
||||||
|
|
||||||
if 'a' in os.environ["RWKV_MY_TESTING"]:
|
if "a" in os.environ["RWKV_MY_TESTING"]:
|
||||||
|
|
||||||
@MyFunction
|
@MyFunction
|
||||||
def QKV(self, q, k, v):
|
def QKV(self, q, k, v):
|
||||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||||
att = att.masked_fill(self.att_mask == 0, float('-inf'))
|
att = att.masked_fill(self.att_mask == 0, float("-inf"))
|
||||||
att = F.softmax(att, dim=-1)
|
att = F.softmax(att, dim=-1)
|
||||||
x = att @ v
|
x = att @ v
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@MyFunction
|
@MyFunction
|
||||||
def jit_funcQKV(self, x):
|
def jit_funcQKV(self, x):
|
||||||
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
|
xx = self.time_shift(
|
||||||
|
x
|
||||||
|
) # Mix x with the previous timestep to produce xk, xv, xr
|
||||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||||
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||||
@ -296,12 +405,16 @@ class RWKV_TimeMix(MyModule):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, T, C = x.size() # x = (Batch,Time,Channel)
|
B, T, C = x.size() # x = (Batch,Time,Channel)
|
||||||
sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
|
sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
|
||||||
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
|
rwkv = sr * RUN_CUDA(
|
||||||
|
B, T, self.args.dim_att, self.time_decay, self.time_first, k, v
|
||||||
|
)
|
||||||
rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
|
rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
|
||||||
return rwkv
|
return rwkv
|
||||||
|
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
|
||||||
|
|
||||||
class RWKV_ChannelMix(MyModule):
|
class RWKV_ChannelMix(MyModule):
|
||||||
def __init__(self, args, layer_id):
|
def __init__(self, args, layer_id):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -331,6 +444,7 @@ class RWKV_ChannelMix(MyModule):
|
|||||||
kv = self.value(k)
|
kv = self.value(k)
|
||||||
return torch.sigmoid(self.receptance(xr)) * kv
|
return torch.sigmoid(self.receptance(xr)) * kv
|
||||||
|
|
||||||
|
|
||||||
class MishGLU(MyModule):
|
class MishGLU(MyModule):
|
||||||
def __init__(self, args, layer_id):
|
def __init__(self, args, layer_id):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -360,6 +474,7 @@ class MishGLU(MyModule):
|
|||||||
b = self.bb(xb)
|
b = self.bb(xb)
|
||||||
return self.value(a * F.mish(b))
|
return self.value(a * F.mish(b))
|
||||||
|
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
# The RWKV Model with our blocks
|
# The RWKV Model with our blocks
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
@ -377,15 +492,19 @@ class Block(nn.Module):
|
|||||||
if self.layer_id == 0:
|
if self.layer_id == 0:
|
||||||
self.ln0 = nn.LayerNorm(args.n_embd)
|
self.ln0 = nn.LayerNorm(args.n_embd)
|
||||||
if args.my_pos_emb > 0:
|
if args.my_pos_emb > 0:
|
||||||
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
|
self.pos_emb_x = nn.Parameter(
|
||||||
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
|
torch.zeros((1, args.my_pos_emb, args.n_embd))
|
||||||
|
)
|
||||||
|
self.pos_emb_y = nn.Parameter(
|
||||||
|
torch.zeros((args.my_pos_emb, 1, args.n_embd))
|
||||||
|
)
|
||||||
|
|
||||||
if self.layer_id == 0 and self.args.pre_ffn > 0:
|
if self.layer_id == 0 and self.args.pre_ffn > 0:
|
||||||
self.ffnPre = RWKV_ChannelMix(args, 0)
|
self.ffnPre = RWKV_ChannelMix(args, 0)
|
||||||
else:
|
else:
|
||||||
self.att = RWKV_TimeMix(args, layer_id)
|
self.att = RWKV_TimeMix(args, layer_id)
|
||||||
|
|
||||||
if 'g' in os.environ["RWKV_MY_TESTING"]:
|
if "g" in os.environ["RWKV_MY_TESTING"]:
|
||||||
self.ffn = MishGLU(args, layer_id)
|
self.ffn = MishGLU(args, layer_id)
|
||||||
else:
|
else:
|
||||||
self.ffn = RWKV_ChannelMix(args, layer_id)
|
self.ffn = RWKV_ChannelMix(args, layer_id)
|
||||||
@ -395,7 +514,9 @@ class Block(nn.Module):
|
|||||||
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||||
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||||
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
|
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
|
||||||
self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
self.register_buffer(
|
||||||
|
"tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, x_emb=None):
|
def forward(self, x, x_emb=None):
|
||||||
args = self.args
|
args = self.args
|
||||||
@ -443,13 +564,13 @@ class RWKV(pl.LightningModule):
|
|||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
if not hasattr(args, 'dim_att'):
|
if not hasattr(args, "dim_att"):
|
||||||
args.dim_att = args.n_embd
|
args.dim_att = args.n_embd
|
||||||
if not hasattr(args, 'dim_ffn'):
|
if not hasattr(args, "dim_ffn"):
|
||||||
args.dim_ffn = args.n_embd * 4
|
args.dim_ffn = args.n_embd * 4
|
||||||
if not hasattr(args, 'tiny_att_layer'):
|
if not hasattr(args, "tiny_att_layer"):
|
||||||
args.tiny_att_layer = -1
|
args.tiny_att_layer = -1
|
||||||
if not hasattr(args, 'tiny_att_dim'):
|
if not hasattr(args, "tiny_att_dim"):
|
||||||
args.tiny_att_dim = -1
|
args.tiny_att_dim = -1
|
||||||
|
|
||||||
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
|
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
|
||||||
@ -462,7 +583,9 @@ class RWKV(pl.LightningModule):
|
|||||||
if args.head_qk > 0:
|
if args.head_qk > 0:
|
||||||
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||||
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||||
self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
self.register_buffer(
|
||||||
|
"copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||||||
|
)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
args = self.args
|
args = self.args
|
||||||
@ -494,19 +617,46 @@ class RWKV(pl.LightningModule):
|
|||||||
param_dict = {n: p for n, p in self.named_parameters()}
|
param_dict = {n: p for n, p in self.named_parameters()}
|
||||||
if args.my_pile_stage == 2:
|
if args.my_pile_stage == 2:
|
||||||
optim_groups = [
|
optim_groups = [
|
||||||
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
{
|
||||||
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
|
"params": [param_dict[n] for n in lr_1x],
|
||||||
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 1.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_2x],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 5.0,
|
||||||
|
}, # test: 2e-3 / args.lr_init},
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_3x],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 5.0,
|
||||||
|
}, # test: 3e-3 / args.lr_init},
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
optim_groups = [
|
optim_groups = [
|
||||||
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
{
|
||||||
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
|
"params": [param_dict[n] for n in lr_1x],
|
||||||
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 1.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_2x],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 2.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_3x],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 3.0,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
optim_groups = [
|
optim_groups = [
|
||||||
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
|
{
|
||||||
|
"params": [p for n, p in self.named_parameters()],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
for g in optim_groups:
|
for g in optim_groups:
|
||||||
@ -514,8 +664,26 @@ class RWKV(pl.LightningModule):
|
|||||||
optim_groups = [g for g in optim_groups if len(g["params"]) > 0]
|
optim_groups = [g for g in optim_groups if len(g["params"]) > 0]
|
||||||
|
|
||||||
if self.deepspeed_offload:
|
if self.deepspeed_offload:
|
||||||
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
|
return DeepSpeedCPUAdam(
|
||||||
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
|
optim_groups,
|
||||||
|
lr=self.args.lr_init,
|
||||||
|
betas=self.args.betas,
|
||||||
|
eps=self.args.adam_eps,
|
||||||
|
bias_correction=True,
|
||||||
|
adamw_mode=False,
|
||||||
|
weight_decay=0,
|
||||||
|
amsgrad=False,
|
||||||
|
)
|
||||||
|
return FusedAdam(
|
||||||
|
optim_groups,
|
||||||
|
lr=self.args.lr_init,
|
||||||
|
betas=self.args.betas,
|
||||||
|
eps=self.args.adam_eps,
|
||||||
|
bias_correction=True,
|
||||||
|
adam_w_mode=False,
|
||||||
|
weight_decay=0,
|
||||||
|
amsgrad=False,
|
||||||
|
)
|
||||||
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
|
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -589,10 +757,14 @@ class RWKV(pl.LightningModule):
|
|||||||
|
|
||||||
logits = self(idx)
|
logits = self(idx)
|
||||||
if sum_mask == mask.shape[0]:
|
if sum_mask == mask.shape[0]:
|
||||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
loss = F.cross_entropy(
|
||||||
|
logits.view(-1, logits.size(-1)), targets.view(-1)
|
||||||
|
)
|
||||||
# print('rank', self.global_rank, 'loss', loss.item())
|
# print('rank', self.global_rank, 'loss', loss.item())
|
||||||
else:
|
else:
|
||||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
|
loss = F.cross_entropy(
|
||||||
|
logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
|
||||||
|
)
|
||||||
# loss_raw = loss
|
# loss_raw = loss
|
||||||
loss = torch.sum(loss * mask) / sum_mask
|
loss = torch.sum(loss * mask) / sum_mask
|
||||||
|
|
||||||
@ -632,7 +804,14 @@ class RWKV(pl.LightningModule):
|
|||||||
|
|
||||||
gain = 1.0
|
gain = 1.0
|
||||||
scale = 1.0
|
scale = 1.0
|
||||||
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
|
if (
|
||||||
|
"ln_" in n
|
||||||
|
or ".ln" in n
|
||||||
|
or "time_" in n
|
||||||
|
or "_mask" in n
|
||||||
|
or "pos_emb" in n
|
||||||
|
or ".mask." in n
|
||||||
|
):
|
||||||
m[n] = p
|
m[n] = p
|
||||||
else:
|
else:
|
||||||
if n == "emb.weight":
|
if n == "emb.weight":
|
||||||
@ -640,7 +819,19 @@ class RWKV(pl.LightningModule):
|
|||||||
else:
|
else:
|
||||||
if shape[0] > shape[1]:
|
if shape[0] > shape[1]:
|
||||||
gain = math.sqrt(shape[0] / shape[1])
|
gain = math.sqrt(shape[0] / shape[1])
|
||||||
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
|
for kk in [
|
||||||
|
".att.key.",
|
||||||
|
".att.receptance.",
|
||||||
|
".att.output.",
|
||||||
|
".att.key.",
|
||||||
|
".ffn.value.",
|
||||||
|
".ffn.receptance.",
|
||||||
|
".ffnPre.value.",
|
||||||
|
".ffnPre.receptance.",
|
||||||
|
"head_q.",
|
||||||
|
".oo.",
|
||||||
|
".rr.",
|
||||||
|
]:
|
||||||
if kk in n:
|
if kk in n:
|
||||||
scale = 0
|
scale = 0
|
||||||
if n == "head.weight":
|
if n == "head.weight":
|
||||||
@ -650,7 +841,9 @@ class RWKV(pl.LightningModule):
|
|||||||
if "head_q." in n:
|
if "head_q." in n:
|
||||||
scale = 0
|
scale = 0
|
||||||
|
|
||||||
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
|
print(
|
||||||
|
f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}"
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.accelerator.upper() == "GPU":
|
if self.args.accelerator.upper() == "GPU":
|
||||||
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
|
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
|
@ -5,15 +5,17 @@ import pytorch_lightning as pl
|
|||||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||||
from .model import LORA_CONFIG
|
from .model import LORA_CONFIG
|
||||||
|
|
||||||
|
|
||||||
def my_save(dd, ff):
|
def my_save(dd, ff):
|
||||||
if '14b-run1' not in ff:
|
if "14b-run1" not in ff:
|
||||||
torch.save(dd, ff)
|
torch.save(dd, ff)
|
||||||
else:
|
else:
|
||||||
fn = ff.split('/')[-1]
|
fn = ff.split("/")[-1]
|
||||||
fff = '/dev/shm/' + fn
|
fff = "/dev/shm/" + fn
|
||||||
torch.save(dd, fff)
|
torch.save(dd, fff)
|
||||||
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
|
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
|
||||||
|
|
||||||
|
|
||||||
class train_callback(pl.Callback):
|
class train_callback(pl.Callback):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -38,7 +40,9 @@ class train_callback(pl.Callback):
|
|||||||
if args.lr_final == 0 or args.lr_init == 0: # linear decay
|
if args.lr_final == 0 or args.lr_init == 0: # linear decay
|
||||||
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
|
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
|
||||||
else: # exp decay
|
else: # exp decay
|
||||||
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
|
lr = args.lr_init * math.exp(
|
||||||
|
math.log(args.lr_final / args.lr_init) * pow(progress, 1)
|
||||||
|
)
|
||||||
|
|
||||||
if trainer.global_step < w_step:
|
if trainer.global_step < w_step:
|
||||||
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
|
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
|
||||||
@ -60,7 +64,9 @@ class train_callback(pl.Callback):
|
|||||||
trainer.my_loss_sum = 0
|
trainer.my_loss_sum = 0
|
||||||
trainer.my_loss_count = 0
|
trainer.my_loss_count = 0
|
||||||
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
|
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
|
||||||
trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
|
trainer.my_log.write(
|
||||||
|
f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
print(f"\n{trainer.strategy.config}\n")
|
print(f"\n{trainer.strategy.config}\n")
|
||||||
trainer.my_log.write(f"{trainer.strategy.config}\n")
|
trainer.my_log.write(f"{trainer.strategy.config}\n")
|
||||||
@ -70,6 +76,7 @@ class train_callback(pl.Callback):
|
|||||||
if len(args.wandb) > 0:
|
if len(args.wandb) > 0:
|
||||||
print("Login to wandb...")
|
print("Login to wandb...")
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
wandb.init(
|
wandb.init(
|
||||||
project=args.wandb,
|
project=args.wandb,
|
||||||
name=args.run_name + " " + args.my_timestamp,
|
name=args.run_name + " " + args.my_timestamp,
|
||||||
@ -102,20 +109,26 @@ class train_callback(pl.Callback):
|
|||||||
# self.log("s", real_step, prog_bar=True, on_step=True)
|
# self.log("s", real_step, prog_bar=True, on_step=True)
|
||||||
|
|
||||||
if len(args.wandb) > 0:
|
if len(args.wandb) > 0:
|
||||||
lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
|
lll = {
|
||||||
|
"loss": trainer.my_loss,
|
||||||
|
"lr": trainer.my_lr,
|
||||||
|
"Gtokens": real_step * token_per_step / 1e9,
|
||||||
|
}
|
||||||
if kt_s > 0:
|
if kt_s > 0:
|
||||||
lll["kt/s"] = kt_s
|
lll["kt/s"] = kt_s
|
||||||
trainer.my_wandb.log(lll, step=int(real_step))
|
trainer.my_wandb.log(lll, step=int(real_step))
|
||||||
if args.magic_prime > 0:
|
if args.magic_prime > 0:
|
||||||
expand_factor = 2 if args.my_qa_mask > 0 else 1
|
expand_factor = 2 if args.my_qa_mask > 0 else 1
|
||||||
if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1:
|
if (
|
||||||
|
int(real_step)
|
||||||
|
== int(args.magic_prime * expand_factor // args.real_bsz) - 1
|
||||||
|
):
|
||||||
to_save_dict = pl_module.state_dict()
|
to_save_dict = pl_module.state_dict()
|
||||||
my_save(
|
my_save(
|
||||||
to_save_dict,
|
to_save_dict,
|
||||||
f"{args.proj_dir}/rwkv-final.pth",
|
f"{args.proj_dir}/rwkv-final.pth",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def on_train_epoch_start(self, trainer, pl_module):
|
def on_train_epoch_start(self, trainer, pl_module):
|
||||||
args = self.args
|
args = self.args
|
||||||
dataset = trainer.train_dataloader.dataset.datasets
|
dataset = trainer.train_dataloader.dataset.datasets
|
||||||
@ -128,24 +141,28 @@ class train_callback(pl.Callback):
|
|||||||
def on_train_epoch_end(self, trainer, pl_module):
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
args = self.args
|
args = self.args
|
||||||
if trainer.is_global_zero: # logging & save state_dict
|
if trainer.is_global_zero: # logging & save state_dict
|
||||||
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
|
if (
|
||||||
if args.data_type == 'wds_img':
|
args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0
|
||||||
|
) or trainer.current_epoch == args.epoch_count - 1:
|
||||||
|
if args.data_type == "wds_img":
|
||||||
raw_dict = pl_module.state_dict()
|
raw_dict = pl_module.state_dict()
|
||||||
to_save_dict = {}
|
to_save_dict = {}
|
||||||
for k in raw_dict:
|
for k in raw_dict:
|
||||||
if k.startswith('encoder.') or k.startswith('decoder.'):
|
if k.startswith("encoder.") or k.startswith("decoder."):
|
||||||
to_save_dict[k] = raw_dict[k]
|
to_save_dict[k] = raw_dict[k]
|
||||||
else:
|
else:
|
||||||
to_save_dict = pl_module.state_dict()
|
to_save_dict = pl_module.state_dict()
|
||||||
|
|
||||||
if args.lora:
|
if args.lora:
|
||||||
enable_time_finetune = 'time' in LORA_CONFIG["parts"]
|
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
||||||
enable_ln_finetune = 'ln' in LORA_CONFIG["parts"]
|
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
||||||
lora_dict = {}
|
lora_dict = {}
|
||||||
for name, state in to_save_dict.items():
|
for name, state in to_save_dict.items():
|
||||||
if ('.lora_' in name
|
if (
|
||||||
or (enable_time_finetune and '.time_' in name)
|
".lora_" in name
|
||||||
or (enable_ln_finetune and '.ln' in name)):
|
or (enable_time_finetune and ".time_" in name)
|
||||||
|
or (enable_ln_finetune and ".ln" in name)
|
||||||
|
):
|
||||||
lora_dict[name] = state
|
lora_dict[name] = state
|
||||||
to_save_dict = lora_dict
|
to_save_dict = lora_dict
|
||||||
|
|
||||||
@ -155,8 +172,10 @@ class train_callback(pl.Callback):
|
|||||||
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
|
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('Error\n\n', e, '\n\n')
|
print("Error\n\n", e, "\n\n")
|
||||||
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
|
trainer.my_log.write(
|
||||||
|
f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n"
|
||||||
|
)
|
||||||
trainer.my_log.flush()
|
trainer.my_log.flush()
|
||||||
|
|
||||||
trainer.my_loss_sum = 0
|
trainer.my_loss_sum = 0
|
||||||
@ -178,7 +197,7 @@ def generate_init_weight(model, init_weight_name):
|
|||||||
mm[k] = src.reshape(mm[k].shape)
|
mm[k] = src.reshape(mm[k].shape)
|
||||||
except:
|
except:
|
||||||
tmp = mm[k].squeeze().clone()
|
tmp = mm[k].squeeze().clone()
|
||||||
print(k, src.shape, '-->', mm[k].shape)
|
print(k, src.shape, "-->", mm[k].shape)
|
||||||
ss = src.shape[0]
|
ss = src.shape[0]
|
||||||
dd = tmp.shape[0]
|
dd = tmp.shape[0]
|
||||||
for i in range(dd):
|
for i in range(dd):
|
||||||
@ -191,9 +210,9 @@ def generate_init_weight(model, init_weight_name):
|
|||||||
tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii)
|
tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii)
|
||||||
mm[k] = tmp.reshape(mm[k].shape)
|
mm[k] = tmp.reshape(mm[k].shape)
|
||||||
sss = src.squeeze().float().cpu().numpy()
|
sss = src.squeeze().float().cpu().numpy()
|
||||||
print(sss[:10], '...', sss[-10:])
|
print(sss[:10], "...", sss[-10:])
|
||||||
mmm = mm[k].squeeze().float().cpu().numpy()
|
mmm = mm[k].squeeze().float().cpu().numpy()
|
||||||
print(mmm[:10], '...', mmm[-10:])
|
print(mmm[:10], "...", mmm[-10:])
|
||||||
|
|
||||||
print(f"Save to {init_weight_name}...")
|
print(f"Save to {init_weight_name}...")
|
||||||
torch.save(mm, init_weight_name)
|
torch.save(mm, init_weight_name)
|
@ -6,6 +6,7 @@ from torch.nn import functional as F
|
|||||||
time_slot = {}
|
time_slot = {}
|
||||||
time_ref = time.time_ns()
|
time_ref = time.time_ns()
|
||||||
|
|
||||||
|
|
||||||
def record_time(name):
|
def record_time(name):
|
||||||
if name not in time_slot:
|
if name not in time_slot:
|
||||||
time_slot[name] = 1e20
|
time_slot[name] = 1e20
|
||||||
@ -13,20 +14,23 @@ def record_time(name):
|
|||||||
if tt < time_slot[name]:
|
if tt < time_slot[name]:
|
||||||
time_slot[name] = tt
|
time_slot[name] = tt
|
||||||
|
|
||||||
class TOKENIZER():
|
|
||||||
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
|
class TOKENIZER:
|
||||||
if 'list' in str(type(WORD_NAME)):
|
def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"):
|
||||||
|
if "list" in str(type(WORD_NAME)):
|
||||||
self.charMode = False
|
self.charMode = False
|
||||||
if WORD_NAME[0] == WORD_NAME[1]:
|
if WORD_NAME[0] == WORD_NAME[1]:
|
||||||
from transformers import PreTrainedTokenizerFast
|
from transformers import PreTrainedTokenizerFast
|
||||||
|
|
||||||
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
|
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
|
||||||
else:
|
else:
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
|
|
||||||
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
|
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
|
||||||
self.vocab_size = len(self.tokenizer)
|
self.vocab_size = len(self.tokenizer)
|
||||||
else:
|
else:
|
||||||
self.charMode = True
|
self.charMode = True
|
||||||
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
|
with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file:
|
||||||
self.word_table = json.load(result_file)
|
self.word_table = json.load(result_file)
|
||||||
|
|
||||||
self.vocab_size = len(self.word_table)
|
self.vocab_size = len(self.word_table)
|
||||||
@ -37,23 +41,25 @@ class TOKENIZER():
|
|||||||
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
|
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
|
||||||
|
|
||||||
def refine_context(self, context):
|
def refine_context(self, context):
|
||||||
context = context.strip().split('\n')
|
context = context.strip().split("\n")
|
||||||
for c in range(len(context)):
|
for c in range(len(context)):
|
||||||
context[c] = context[c].strip().strip('\u3000').strip('\r')
|
context[c] = context[c].strip().strip("\u3000").strip("\r")
|
||||||
context = list(filter(lambda c: c != '', context))
|
context = list(filter(lambda c: c != "", context))
|
||||||
context = '\n' + ('\n'.join(context)).strip()
|
context = "\n" + ("\n".join(context)).strip()
|
||||||
if context == '':
|
if context == "":
|
||||||
context = '\n'
|
context = "\n"
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
|
def sample_logits(
|
||||||
|
self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None
|
||||||
|
):
|
||||||
# out[self.UNKNOWN_CHAR] = -float('Inf')
|
# out[self.UNKNOWN_CHAR] = -float('Inf')
|
||||||
lastChar = int(x[-1])
|
lastChar = int(x[-1])
|
||||||
|
|
||||||
probs = F.softmax(out, dim=-1)
|
probs = F.softmax(out, dim=-1)
|
||||||
|
|
||||||
if self.charMode:
|
if self.charMode:
|
||||||
if self.itos[lastChar] == '\n':
|
if self.itos[lastChar] == "\n":
|
||||||
top_p = top_p_newline
|
top_p = top_p_newline
|
||||||
else:
|
else:
|
||||||
top_p = top_p_usual
|
top_p = top_p_usual
|
||||||
@ -81,6 +87,7 @@ class TOKENIZER():
|
|||||||
out = torch.multinomial(probs, num_samples=1)[0]
|
out = torch.multinomial(probs, num_samples=1)[0]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def MaybeIsPrime(number):
|
def MaybeIsPrime(number):
|
||||||
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
|
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
|
||||||
return True
|
return True
|
||||||
@ -121,7 +128,9 @@ def MillerRabinPrimalityTest(number):
|
|||||||
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
|
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
|
||||||
iterationNumber = 1
|
iterationNumber = 1
|
||||||
|
|
||||||
while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
|
while (iterationNumber <= timesTwoDividNumber - 1) and (
|
||||||
|
randomNumberWithPower != number - 1
|
||||||
|
):
|
||||||
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
|
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
|
||||||
iterationNumber = iterationNumber + 1
|
iterationNumber = iterationNumber + 1
|
||||||
if randomNumberWithPower != (number - 1):
|
if randomNumberWithPower != (number - 1):
|
@ -184,7 +184,7 @@ if __name__ == "__main__":
|
|||||||
args.num_sanity_val_steps = 0
|
args.num_sanity_val_steps = 0
|
||||||
args.check_val_every_n_epoch = int(1e20)
|
args.check_val_every_n_epoch = int(1e20)
|
||||||
args.log_every_n_steps = int(1e20)
|
args.log_every_n_steps = int(1e20)
|
||||||
args.max_epochs = args.epoch_count # continue forever
|
args.max_epochs = args.epoch_count # -1 continue forever
|
||||||
args.betas = (args.beta1, args.beta2)
|
args.betas = (args.beta1, args.beta2)
|
||||||
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
|
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
|
||||||
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
|
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
|
202
finetune/lora/v5/cuda/wkv5_cuda.cu
vendored
Normal file
202
finetune/lora/v5/cuda/wkv5_cuda.cu
vendored
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
#include <stdio.h>
|
||||||
|
#include <assert.h>
|
||||||
|
#include "ATen/ATen.h"
|
||||||
|
typedef at::BFloat16 bf16;
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
|
||||||
|
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
|
||||||
|
F *__restrict__ const _y)
|
||||||
|
{
|
||||||
|
const int b = blockIdx.x / H;
|
||||||
|
const int h = blockIdx.x % H;
|
||||||
|
const int i = threadIdx.x;
|
||||||
|
_w += h*_N_;
|
||||||
|
_u += h*_N_;
|
||||||
|
|
||||||
|
__shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
|
||||||
|
float state[_N_] = {0};
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
w[i] = _w[i];
|
||||||
|
u[i] = float(_u[i]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
r[i] = float(_r[t]);
|
||||||
|
k[i] = float(_k[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float v = float(_v[t]);
|
||||||
|
float y = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j+=4)
|
||||||
|
{
|
||||||
|
const float4& r_ = (float4&)(r[j]);
|
||||||
|
const float4& k_ = (float4&)(k[j]);
|
||||||
|
const float4& w_ = (float4&)(w[j]);
|
||||||
|
const float4& u_ = (float4&)(u[j]);
|
||||||
|
float4& s = (float4&)(state[j]);
|
||||||
|
float4 x;
|
||||||
|
|
||||||
|
x.x = k_.x * v;
|
||||||
|
x.y = k_.y * v;
|
||||||
|
x.z = k_.z * v;
|
||||||
|
x.w = k_.w * v;
|
||||||
|
|
||||||
|
y += r_.x * (u_.x * x.x + s.x);
|
||||||
|
y += r_.y * (u_.y * x.y + s.y);
|
||||||
|
y += r_.z * (u_.z * x.z + s.z);
|
||||||
|
y += r_.w * (u_.w * x.w + s.w);
|
||||||
|
|
||||||
|
s.x = s.x * w_.x + x.x;
|
||||||
|
s.y = s.y * w_.y + x.y;
|
||||||
|
s.z = s.z * w_.z + x.z;
|
||||||
|
s.w = s.w * w_.w + x.w;
|
||||||
|
}
|
||||||
|
_y[t] = F(y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
__global__ void kernel_backward(const int B, const int T, const int C, const int H,
|
||||||
|
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
|
||||||
|
F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
|
||||||
|
{
|
||||||
|
const int b = blockIdx.x / H;
|
||||||
|
const int h = blockIdx.x % H;
|
||||||
|
const int i = threadIdx.x;
|
||||||
|
_w += h*_N_;
|
||||||
|
_u += h*_N_;
|
||||||
|
__w += h*_N_;
|
||||||
|
|
||||||
|
__shared__ float w_[_N_], u_[_N_];
|
||||||
|
__shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
|
||||||
|
__syncthreads();
|
||||||
|
w_[i] = _w[i];
|
||||||
|
u_[i] = float(_u[i]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float w = w_[i];
|
||||||
|
const float ww = __w[i];
|
||||||
|
const float u = u_[i];
|
||||||
|
|
||||||
|
float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
|
||||||
|
|
||||||
|
float gw = 0, gu = 0;
|
||||||
|
const int t000 = b*T*C + h*_N_ + i;
|
||||||
|
const int t111 = (b+1)*T*C + h*_N_ + i;
|
||||||
|
const int t222 = t111 - 2*C;
|
||||||
|
|
||||||
|
for (int t = t000; t < t111; t += C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
v[i] = float(_v[t]);
|
||||||
|
gy[i] = float(_gy[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float k = float(_k[t]);
|
||||||
|
float gr = 0, gu_ = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = state[j];
|
||||||
|
float x = k * v[j];
|
||||||
|
|
||||||
|
gr += (u * x + s) * gy[j];
|
||||||
|
gu_ += x * gy[j];
|
||||||
|
s = s * w + x;
|
||||||
|
}
|
||||||
|
_gr[t] = F(gr);
|
||||||
|
gu += float(_r[t]) * gu_;
|
||||||
|
}
|
||||||
|
_gu[b*C + h*_N_ + i] = F(gu);
|
||||||
|
|
||||||
|
for (int t = t000; t < t222; t += C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
v[i] = float(_v[t]);
|
||||||
|
gy[i] = float(_gy[t + 2*C]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float k = float(_k[t]);
|
||||||
|
float gw_ = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = saaaa[j];
|
||||||
|
float& s2 = sbbbb[j];
|
||||||
|
float x = k * v[j];
|
||||||
|
|
||||||
|
float tmp = w * (x + s);
|
||||||
|
s = tmp;
|
||||||
|
s2 = tmp + w * s2;
|
||||||
|
gw_ += s2 * gy[j];
|
||||||
|
}
|
||||||
|
gw += float(_r[t + 2*C]) * gw_;
|
||||||
|
}
|
||||||
|
_gw[b*C + h*_N_ + i] = F(ww * gw);
|
||||||
|
|
||||||
|
for (int t = t111 - C; t >= t000; t -= C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
v[i] = float(_v[t]);
|
||||||
|
gy[i] = float(_gy[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float rr = float(_r[t]);
|
||||||
|
float gk = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = scccc[j];
|
||||||
|
float x = rr * gy[j];
|
||||||
|
|
||||||
|
gk += (u * x + s) * v[j];
|
||||||
|
s = x + s * w;
|
||||||
|
}
|
||||||
|
_gk[t] = F(gk);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int t = t111 - C; t >= t000; t -= C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
r[i] = float(_r[t]);
|
||||||
|
k[i] = float(_k[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float gyy = float(_gy[t]);
|
||||||
|
float gv = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = sdddd[j];
|
||||||
|
float x = gyy * r[j];
|
||||||
|
|
||||||
|
gv += (u_[j] * x + s) * k[j];
|
||||||
|
s = x + s * w_[j];
|
||||||
|
}
|
||||||
|
_gv[t] = F(gv);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
|
||||||
|
{
|
||||||
|
assert(H*_N_ == C);
|
||||||
|
assert(_N_%4 == 0);
|
||||||
|
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
|
||||||
|
{
|
||||||
|
assert(H*_N_ == C);
|
||||||
|
assert(_N_%4 == 0);
|
||||||
|
kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
|
||||||
|
}
|
22
finetune/lora/v5/cuda/wkv5_op.cpp
vendored
Normal file
22
finetune/lora/v5/cuda/wkv5_op.cpp
vendored
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
#include "ATen/ATen.h"
|
||||||
|
typedef at::BFloat16 bf16;
|
||||||
|
|
||||||
|
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
|
||||||
|
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
|
||||||
|
|
||||||
|
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||||
|
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
|
||||||
|
}
|
||||||
|
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
|
||||||
|
cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
|
||||||
|
}
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("forward", &forward, "wkv5 forward");
|
||||||
|
m.def("backward", &backward, "wkv5 backward");
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY(wkv5, m) {
|
||||||
|
m.def("forward", forward);
|
||||||
|
m.def("backward", backward);
|
||||||
|
}
|
0
finetune/lora/v5/src/__init__.py
vendored
Normal file
0
finetune/lora/v5/src/__init__.py
vendored
Normal file
303
finetune/lora/v5/src/binidx.py
vendored
Normal file
303
finetune/lora/v5/src/binidx.py
vendored
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
from lib2to3.pgen2 import token
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import shutil
|
||||||
|
import struct
|
||||||
|
from functools import lru_cache
|
||||||
|
from itertools import accumulate
|
||||||
|
|
||||||
|
|
||||||
|
def print_rank_0(*message):
|
||||||
|
pass
|
||||||
|
# """If distributed is initialized print only on rank 0."""
|
||||||
|
# if torch.distributed.is_initialized():
|
||||||
|
# if torch.distributed.get_rank() == 0:
|
||||||
|
# print(*message, flush=True)
|
||||||
|
# else:
|
||||||
|
# print(*message, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _warmup_mmap_file(path):
|
||||||
|
pass
|
||||||
|
# with open(path, "rb") as stream:
|
||||||
|
# while stream.read(100 * 1024 * 1024):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
|
dtypes = {
|
||||||
|
1: np.uint8,
|
||||||
|
2: np.int8,
|
||||||
|
3: np.int16,
|
||||||
|
4: np.int32,
|
||||||
|
5: np.int64,
|
||||||
|
6: float,
|
||||||
|
7: np.double,
|
||||||
|
8: np.uint16,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def code(dtype):
|
||||||
|
for k in dtypes.keys():
|
||||||
|
if dtypes[k] == dtype:
|
||||||
|
return k
|
||||||
|
raise ValueError(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def index_file_path(prefix_path):
|
||||||
|
return prefix_path + ".idx"
|
||||||
|
|
||||||
|
|
||||||
|
def data_file_path(prefix_path):
|
||||||
|
return prefix_path + ".bin"
|
||||||
|
|
||||||
|
|
||||||
|
class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||||
|
class Index(object):
|
||||||
|
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def writer(cls, path, dtype):
|
||||||
|
class _Writer(object):
|
||||||
|
def __enter__(self):
|
||||||
|
self._file = open(path, "wb")
|
||||||
|
|
||||||
|
# Write Magic string so we can check the file format then opening it again.
|
||||||
|
self._file.write(cls._HDR_MAGIC)
|
||||||
|
# Write version number
|
||||||
|
# Little endian unsigned 64 Bit integer
|
||||||
|
self._file.write(struct.pack("<Q", 1))
|
||||||
|
# Little endian unsigned 8 Bit integer
|
||||||
|
self._file.write(struct.pack("<B", code(dtype)))
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_pointers(sizes):
|
||||||
|
dtype_size = dtype().itemsize
|
||||||
|
address = 0
|
||||||
|
pointers = []
|
||||||
|
|
||||||
|
for size in sizes:
|
||||||
|
pointers.append(address)
|
||||||
|
address += size * dtype_size
|
||||||
|
|
||||||
|
return pointers
|
||||||
|
|
||||||
|
def write(self, sizes, doc_idx):
|
||||||
|
pointers = self._get_pointers(sizes)
|
||||||
|
|
||||||
|
# Little endian unsigned 64 Bit integer
|
||||||
|
self._file.write(struct.pack("<Q", len(sizes)))
|
||||||
|
# Little endian unsigned 64 Bit integer
|
||||||
|
self._file.write(struct.pack("<Q", len(doc_idx)))
|
||||||
|
|
||||||
|
sizes = np.array(sizes, dtype=np.int32)
|
||||||
|
self._file.write(sizes.tobytes(order="C"))
|
||||||
|
del sizes
|
||||||
|
|
||||||
|
pointers = np.array(pointers, dtype=np.int64)
|
||||||
|
self._file.write(pointers.tobytes(order="C"))
|
||||||
|
del pointers
|
||||||
|
|
||||||
|
doc_idx = np.array(doc_idx, dtype=np.int64)
|
||||||
|
self._file.write(doc_idx.tobytes(order="C"))
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self._file.close()
|
||||||
|
|
||||||
|
return _Writer()
|
||||||
|
|
||||||
|
def __init__(self, path, skip_warmup=False):
|
||||||
|
with open(path, "rb") as stream:
|
||||||
|
magic_test = stream.read(9)
|
||||||
|
assert self._HDR_MAGIC == magic_test, (
|
||||||
|
"Index file doesn't match expected format. "
|
||||||
|
"Make sure that --dataset-impl is configured properly."
|
||||||
|
)
|
||||||
|
# Little endian unsigned 64 Bit integer
|
||||||
|
version = struct.unpack("<Q", stream.read(8))
|
||||||
|
assert (1,) == version
|
||||||
|
|
||||||
|
# Little endian unsigned 8 Bit integer
|
||||||
|
(dtype_code,) = struct.unpack("<B", stream.read(1))
|
||||||
|
self._dtype = dtypes[dtype_code]
|
||||||
|
self._dtype_size = self._dtype().itemsize
|
||||||
|
|
||||||
|
self._len = struct.unpack("<Q", stream.read(8))[0]
|
||||||
|
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
|
||||||
|
offset = stream.tell()
|
||||||
|
|
||||||
|
if not skip_warmup:
|
||||||
|
print_rank_0(" warming up index mmap file...")
|
||||||
|
_warmup_mmap_file(path)
|
||||||
|
|
||||||
|
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
||||||
|
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||||
|
print_rank_0(" reading sizes...")
|
||||||
|
self._sizes = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
|
||||||
|
)
|
||||||
|
print_rank_0(" reading pointers...")
|
||||||
|
self._pointers = np.frombuffer(
|
||||||
|
self._bin_buffer,
|
||||||
|
dtype=np.int64,
|
||||||
|
count=self._len,
|
||||||
|
offset=offset + self._sizes.nbytes,
|
||||||
|
)
|
||||||
|
print_rank_0(" reading document index...")
|
||||||
|
self._doc_idx = np.frombuffer(
|
||||||
|
self._bin_buffer,
|
||||||
|
dtype=np.int64,
|
||||||
|
count=self._doc_count,
|
||||||
|
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self._bin_buffer_mmap._mmap.close()
|
||||||
|
del self._bin_buffer_mmap
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self._dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sizes(self):
|
||||||
|
return self._sizes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def doc_idx(self):
|
||||||
|
return self._doc_idx
|
||||||
|
|
||||||
|
@lru_cache(maxsize=8)
|
||||||
|
def __getitem__(self, i):
|
||||||
|
return self._pointers[i], self._sizes[i]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._len
|
||||||
|
|
||||||
|
def __init__(self, path, skip_warmup=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._path = None
|
||||||
|
self._index = None
|
||||||
|
self._bin_buffer = None
|
||||||
|
|
||||||
|
self._do_init(path, skip_warmup)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return self._path
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self._do_init(state)
|
||||||
|
|
||||||
|
def _do_init(self, path, skip_warmup):
|
||||||
|
self._path = path
|
||||||
|
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
||||||
|
|
||||||
|
if not skip_warmup:
|
||||||
|
print_rank_0(" warming up data mmap file...")
|
||||||
|
_warmup_mmap_file(data_file_path(self._path))
|
||||||
|
print_rank_0(" creating numpy buffer of mmap...")
|
||||||
|
self._bin_buffer_mmap = np.memmap(
|
||||||
|
data_file_path(self._path), mode="r", order="C"
|
||||||
|
)
|
||||||
|
print_rank_0(" creating memory view of numpy buffer...")
|
||||||
|
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self._bin_buffer_mmap._mmap.close()
|
||||||
|
del self._bin_buffer_mmap
|
||||||
|
del self._index
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._index)
|
||||||
|
|
||||||
|
# @lru_cache(maxsize=8)
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
if isinstance(idx, int):
|
||||||
|
ptr, size = self._index[idx]
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||||
|
)
|
||||||
|
return np_array
|
||||||
|
elif isinstance(idx, slice):
|
||||||
|
start, stop, step = idx.indices(len(self))
|
||||||
|
if step != 1:
|
||||||
|
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||||
|
ptr = self._index._pointers[start]
|
||||||
|
sizes = self._index._sizes[idx]
|
||||||
|
offsets = list(accumulate(sizes))
|
||||||
|
total_size = sum(sizes)
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
|
||||||
|
)
|
||||||
|
sents = np.split(np_array, offsets[:-1])
|
||||||
|
return sents
|
||||||
|
|
||||||
|
def get(self, idx, offset=0, length=None):
|
||||||
|
"""Retrieves a single item from the dataset with the option to only
|
||||||
|
return a portion of the item.
|
||||||
|
|
||||||
|
get(idx) is the same as [idx] but get() does not support slicing.
|
||||||
|
"""
|
||||||
|
ptr, size = self._index[idx]
|
||||||
|
if length is None:
|
||||||
|
length = size - offset
|
||||||
|
ptr += offset * np.dtype(self._index.dtype).itemsize
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
||||||
|
)
|
||||||
|
return np_array
|
||||||
|
|
||||||
|
def pad(self, idx, length=None):
|
||||||
|
ptr, size = self._index[idx]
|
||||||
|
try:
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||||
|
)
|
||||||
|
ptr0, _ = self._index[0]
|
||||||
|
np_array0 = np.frombuffer(
|
||||||
|
self._bin_buffer,
|
||||||
|
dtype=self._index.dtype,
|
||||||
|
count=length - size,
|
||||||
|
offset=ptr0,
|
||||||
|
)
|
||||||
|
np_array = np.append(np_array, np_array0)
|
||||||
|
return np_array
|
||||||
|
|
||||||
|
def only(self, idx):
|
||||||
|
ptr, size = self._index[idx]
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||||
|
)
|
||||||
|
|
||||||
|
return np_array
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sizes(self):
|
||||||
|
return self._index.sizes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def doc_idx(self):
|
||||||
|
return self._index.doc_idx
|
||||||
|
|
||||||
|
def get_doc_idx(self):
|
||||||
|
return self._index._doc_idx
|
||||||
|
|
||||||
|
def set_doc_idx(self, doc_idx_):
|
||||||
|
self._index._doc_idx = doc_idx_
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_prefetch(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def exists(path):
|
||||||
|
return os.path.exists(index_file_path(path)) and os.path.exists(
|
||||||
|
data_file_path(path)
|
||||||
|
)
|
241
finetune/lora/v5/src/dataset.py
vendored
Normal file
241
finetune/lora/v5/src/dataset.py
vendored
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
########################################################################################################
|
||||||
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
import json, math, random, os, sys
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from pytorch_lightning.utilities import rank_zero_info
|
||||||
|
from .binidx import MMapIndexedDataset
|
||||||
|
from .utils import MaybeIsPrime
|
||||||
|
|
||||||
|
|
||||||
|
class MyDataset(Dataset):
|
||||||
|
def __init__(self, args):
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
if args.data_type == "binidx":
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
rank_zero_info(
|
||||||
|
f"Current vocab size = {self.vocab_size} (make sure it's correct)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.my_pile_version == 1:
|
||||||
|
self.data = MMapIndexedDataset(args.data_file)
|
||||||
|
self.data_size = (
|
||||||
|
len(self.data._bin_buffer) // self.data._index._dtype_size
|
||||||
|
)
|
||||||
|
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||||
|
elif args.my_pile_version == 2:
|
||||||
|
data_list = (
|
||||||
|
open(args.data_file, "r", encoding="utf-8")
|
||||||
|
.read()
|
||||||
|
.strip()
|
||||||
|
.split("\n")
|
||||||
|
)
|
||||||
|
data_list = [i.strip().split(" ") for i in data_list]
|
||||||
|
self.data = []
|
||||||
|
self.data_size = int(data_list[-1][-1])
|
||||||
|
rank_zero_info(f"Data has {self.data_size} chunks.")
|
||||||
|
for d in data_list:
|
||||||
|
data = MMapIndexedDataset(d[0])
|
||||||
|
data_size = len(data._bin_buffer) // data._index._dtype_size
|
||||||
|
assert (data_size - args.ctx_len) == int(d[1])
|
||||||
|
self.data += [[int(d[-1]), int(d[1]), data]]
|
||||||
|
# rank_zero_info(self.data)
|
||||||
|
|
||||||
|
if args.my_qa_mask > 0:
|
||||||
|
# self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
|
||||||
|
self.data_pile = MMapIndexedDataset(
|
||||||
|
"/fsx/pile_deduped/pile_0.87_deduped_text_document"
|
||||||
|
)
|
||||||
|
self.data_pile_size = (
|
||||||
|
len(self.data_pile._bin_buffer) // self.data._index._dtype_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.data_pile = None
|
||||||
|
self.data_pile_size = 0
|
||||||
|
|
||||||
|
if args.my_pile_stage > 0:
|
||||||
|
# assert self.data_size == 332115325534 and self.vocab_size == 50277
|
||||||
|
self.samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||||
|
assert self.samples_per_epoch == 40320
|
||||||
|
rank_zero_info(
|
||||||
|
f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########"
|
||||||
|
)
|
||||||
|
dataset_slot = self.data_size // args.ctx_len
|
||||||
|
if args.my_pile_stage != 4:
|
||||||
|
assert MaybeIsPrime(args.magic_prime)
|
||||||
|
assert args.magic_prime % 3 == 2
|
||||||
|
assert (
|
||||||
|
args.magic_prime / dataset_slot > 0.99
|
||||||
|
and args.magic_prime / dataset_slot <= 1
|
||||||
|
)
|
||||||
|
elif args.data_type == "numpy":
|
||||||
|
self.data = np.load(args.data_file).astype("int")
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
rank_zero_info(
|
||||||
|
f"Current vocab size = {self.vocab_size} (make sure it's correct)"
|
||||||
|
)
|
||||||
|
self.data_size = len(self.data)
|
||||||
|
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||||
|
elif args.data_type == "uint16":
|
||||||
|
self.data = (
|
||||||
|
np.fromfile(args.data_file, dtype=np.uint16)
|
||||||
|
.astype("int32")
|
||||||
|
.reshape(-1, args.my_sample_len)
|
||||||
|
)
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
rank_zero_info(
|
||||||
|
f"Current vocab size = {self.vocab_size} (make sure it's correct)"
|
||||||
|
)
|
||||||
|
self.data_size = self.data.shape[0]
|
||||||
|
rank_zero_info(f"Data has {self.data_size} samples.")
|
||||||
|
else:
|
||||||
|
if args.data_type == "dummy":
|
||||||
|
rank_zero_info("Building dummy data...")
|
||||||
|
self.data = ""
|
||||||
|
for i in range(100000):
|
||||||
|
aa = (i) % 10000
|
||||||
|
bb = (i * i) % 10000
|
||||||
|
cc = aa + bb
|
||||||
|
self.data += f".{aa}+{bb}={cc}."
|
||||||
|
else:
|
||||||
|
self.data = open(args.data_file, "r", encoding=args.data_type).read()
|
||||||
|
rank_zero_info("Building token list...")
|
||||||
|
unique = sorted(list(set(self.data)))
|
||||||
|
self.vocab_size = len(unique)
|
||||||
|
# rank_zero_info()
|
||||||
|
# for u in unique:
|
||||||
|
# print(u, end=' ')
|
||||||
|
# rank_zero_info('\n\n')
|
||||||
|
xx = 0
|
||||||
|
xxObj = {}
|
||||||
|
for u in unique:
|
||||||
|
xxObj[xx] = u
|
||||||
|
xx += 1
|
||||||
|
with open(
|
||||||
|
f"{args.proj_dir}/vocab.json", "w", encoding="utf-8"
|
||||||
|
) as vocab_file:
|
||||||
|
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
|
||||||
|
self.data_size = len(self.data)
|
||||||
|
rank_zero_info(
|
||||||
|
f"Data has {self.data_size} tokens, {self.vocab_size} vocab size."
|
||||||
|
)
|
||||||
|
self.stoi = {ch: i for i, ch in enumerate(unique)}
|
||||||
|
self.itos = {i: ch for i, ch in enumerate(unique)}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.args.epoch_steps * self.args.micro_bsz
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
args = self.args
|
||||||
|
rank = self.global_rank
|
||||||
|
epoch = self.real_epoch
|
||||||
|
world_size = self.world_size
|
||||||
|
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
|
||||||
|
|
||||||
|
if args.data_type == "uint16":
|
||||||
|
i = np.random.randint(0, self.data_size - 1)
|
||||||
|
dix = self.data[i]
|
||||||
|
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||||
|
y = torch.tensor(dix[1:], dtype=torch.long)
|
||||||
|
else:
|
||||||
|
ctx_len = args.ctx_len
|
||||||
|
req_len = ctx_len + 1
|
||||||
|
magic_prime = args.magic_prime
|
||||||
|
data = self.data
|
||||||
|
|
||||||
|
if args.my_pile_stage > 0:
|
||||||
|
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
|
||||||
|
|
||||||
|
if args.my_qa_mask > 0:
|
||||||
|
ii_orig = ii
|
||||||
|
if ii % 2 == 0:
|
||||||
|
ii = -1
|
||||||
|
data = self.data_pile
|
||||||
|
else:
|
||||||
|
ii = ii // 2
|
||||||
|
if data == self.data_pile:
|
||||||
|
i = np.random.randint(0, self.data_pile_size - req_len)
|
||||||
|
else:
|
||||||
|
if args.my_pile_stage == 4 or ii < args.my_random_steps:
|
||||||
|
# cheat: pick a random spot in dataset
|
||||||
|
if args.my_pile_version == 1:
|
||||||
|
i = np.random.randint(0, self.data_size - req_len)
|
||||||
|
else:
|
||||||
|
i = np.random.randint(0, self.data_size)
|
||||||
|
else:
|
||||||
|
ii = ii - args.my_random_steps
|
||||||
|
factor = (math.sqrt(5) - 1) / 2
|
||||||
|
factor = int(magic_prime * factor)
|
||||||
|
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
|
||||||
|
i = i + args.my_pile_shift
|
||||||
|
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
|
||||||
|
else:
|
||||||
|
# cheat: pick a random spot in dataset
|
||||||
|
i = np.random.randint(0, self.data_size - req_len)
|
||||||
|
|
||||||
|
if args.data_type == "binidx":
|
||||||
|
if args.my_pile_version == 1:
|
||||||
|
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
|
||||||
|
else:
|
||||||
|
# self.data : cutoff, chunk_count, data
|
||||||
|
for j in range(len(data)):
|
||||||
|
if i < data[j][0]:
|
||||||
|
ii = i
|
||||||
|
i = (i - (data[j - 1][0] if j > 0 else 0)) % data[j][1]
|
||||||
|
dix = (
|
||||||
|
data[j][2]
|
||||||
|
.get(idx=0, offset=i, length=req_len)
|
||||||
|
.astype(int)
|
||||||
|
)
|
||||||
|
# print(ii, j, i)
|
||||||
|
break
|
||||||
|
elif args.data_type == "numpy":
|
||||||
|
dix = data[i : i + req_len]
|
||||||
|
else:
|
||||||
|
dix = [self.stoi[s] for s in data[i : i + req_len]]
|
||||||
|
|
||||||
|
if args.my_qa_mask == 1:
|
||||||
|
if data == self.data_pile:
|
||||||
|
z = [1] * ctx_len
|
||||||
|
else:
|
||||||
|
z = [0] * ctx_len
|
||||||
|
z_sum = 0
|
||||||
|
isGood = False
|
||||||
|
for i in range(3, ctx_len):
|
||||||
|
if (
|
||||||
|
dix[i] == 27
|
||||||
|
and dix[i - 1] == 34
|
||||||
|
and dix[i - 2] == 187
|
||||||
|
and dix[i - 3] == 187
|
||||||
|
):
|
||||||
|
isGood = True
|
||||||
|
if dix[i] == 0:
|
||||||
|
isGood = False
|
||||||
|
if isGood:
|
||||||
|
z[i] = 1
|
||||||
|
z_sum += 1
|
||||||
|
if z_sum == 0:
|
||||||
|
z = [1] * ctx_len
|
||||||
|
i = np.random.randint(0, self.data_pile_size - req_len)
|
||||||
|
dix = self.data_pile.get(
|
||||||
|
idx=0, offset=i, length=req_len
|
||||||
|
).astype(int)
|
||||||
|
z = torch.tensor(z, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||||
|
y = torch.tensor(dix[1:], dtype=torch.long)
|
||||||
|
|
||||||
|
# if ii_orig < 50:
|
||||||
|
# # if rank == 1:
|
||||||
|
# print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
|
||||||
|
# else:
|
||||||
|
# exit(0)
|
||||||
|
|
||||||
|
if args.my_qa_mask == 1:
|
||||||
|
return x, y, z
|
||||||
|
|
||||||
|
return x, y
|
819
finetune/lora/v5/src/model.py
vendored
Normal file
819
finetune/lora/v5/src/model.py
vendored
Normal file
@ -0,0 +1,819 @@
|
|||||||
|
########################################################################################################
|
||||||
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||||
|
########################################################################################################
|
||||||
|
import functools
|
||||||
|
import os, math, gc, importlib
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# torch._C._jit_set_profiling_executor(True)
|
||||||
|
# torch._C._jit_set_profiling_mode(True)
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||||
|
from torch.nn import functional as F
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||||
|
from pytorch_lightning.strategies import DeepSpeedStrategy
|
||||||
|
|
||||||
|
if importlib.util.find_spec("deepspeed"):
|
||||||
|
import deepspeed
|
||||||
|
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
||||||
|
|
||||||
|
|
||||||
|
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
|
||||||
|
|
||||||
|
# lora-config
|
||||||
|
LORA_CONFIG = {
|
||||||
|
"r": 0,
|
||||||
|
"alpha": 0,
|
||||||
|
"dropout": 0,
|
||||||
|
"parts": {"att", "ln", "time"},
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"])
|
||||||
|
except:
|
||||||
|
os.environ["RWKV_MY_TESTING"] = ""
|
||||||
|
|
||||||
|
|
||||||
|
def __nop(ob):
|
||||||
|
return ob
|
||||||
|
|
||||||
|
|
||||||
|
MyModule = nn.Module
|
||||||
|
MyFunction = __nop
|
||||||
|
if os.environ["RWKV_JIT_ON"] == "1":
|
||||||
|
MyModule = torch.jit.ScriptModule
|
||||||
|
MyFunction = torch.jit.script_method
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
# CUDA Kernel
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
from torch.utils.cpp_extension import load
|
||||||
|
|
||||||
|
HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"])
|
||||||
|
wkv5_cuda = load(
|
||||||
|
name="wkv5",
|
||||||
|
sources=[
|
||||||
|
"finetune/lora/v5/cuda/wkv5_op.cpp",
|
||||||
|
f"finetune/lora/v5/cuda/wkv5_cuda.cu",
|
||||||
|
],
|
||||||
|
verbose=True,
|
||||||
|
extra_cuda_cflags=[
|
||||||
|
"-res-usage",
|
||||||
|
"--use_fast_math",
|
||||||
|
"-O3",
|
||||||
|
"-Xptxas -O3",
|
||||||
|
"--extra-device-vectorization",
|
||||||
|
f"-D_N_={HEAD_SIZE}",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WKV_5(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, B, T, C, H, r, k, v, w, u):
|
||||||
|
with torch.no_grad():
|
||||||
|
assert r.dtype == torch.bfloat16
|
||||||
|
assert k.dtype == torch.bfloat16
|
||||||
|
assert v.dtype == torch.bfloat16
|
||||||
|
assert w.dtype == torch.bfloat16
|
||||||
|
assert u.dtype == torch.bfloat16
|
||||||
|
assert HEAD_SIZE == C // H
|
||||||
|
ctx.B = B
|
||||||
|
ctx.T = T
|
||||||
|
ctx.C = C
|
||||||
|
ctx.H = H
|
||||||
|
assert r.is_contiguous()
|
||||||
|
assert k.is_contiguous()
|
||||||
|
assert v.is_contiguous()
|
||||||
|
assert w.is_contiguous()
|
||||||
|
assert u.is_contiguous()
|
||||||
|
ew = (-torch.exp(w.float())).contiguous()
|
||||||
|
eew = (torch.exp(ew)).contiguous()
|
||||||
|
ctx.save_for_backward(r, k, v, eew, ew, u)
|
||||||
|
y = torch.empty(
|
||||||
|
(B, T, C),
|
||||||
|
device=r.device,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
memory_format=torch.contiguous_format,
|
||||||
|
) # .uniform_(-1, 1)
|
||||||
|
wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, gy):
|
||||||
|
with torch.no_grad():
|
||||||
|
assert gy.dtype == torch.bfloat16
|
||||||
|
B = ctx.B
|
||||||
|
T = ctx.T
|
||||||
|
C = ctx.C
|
||||||
|
H = ctx.H
|
||||||
|
assert gy.is_contiguous()
|
||||||
|
r, k, v, eew, ew, u = ctx.saved_tensors
|
||||||
|
gr = torch.empty(
|
||||||
|
(B, T, C),
|
||||||
|
device=gy.device,
|
||||||
|
requires_grad=False,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
memory_format=torch.contiguous_format,
|
||||||
|
) # .uniform_(-1, 1)
|
||||||
|
gk = torch.empty(
|
||||||
|
(B, T, C),
|
||||||
|
device=gy.device,
|
||||||
|
requires_grad=False,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
memory_format=torch.contiguous_format,
|
||||||
|
) # .uniform_(-1, 1)
|
||||||
|
gv = torch.empty(
|
||||||
|
(B, T, C),
|
||||||
|
device=gy.device,
|
||||||
|
requires_grad=False,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
memory_format=torch.contiguous_format,
|
||||||
|
) # .uniform_(-1, 1)
|
||||||
|
gw = torch.empty(
|
||||||
|
(B, C),
|
||||||
|
device=gy.device,
|
||||||
|
requires_grad=False,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
memory_format=torch.contiguous_format,
|
||||||
|
) # .uniform_(-1, 1)
|
||||||
|
gu = torch.empty(
|
||||||
|
(B, C),
|
||||||
|
device=gy.device,
|
||||||
|
requires_grad=False,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
memory_format=torch.contiguous_format,
|
||||||
|
) # .uniform_(-1, 1)
|
||||||
|
wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
|
||||||
|
gw = torch.sum(gw, 0).view(H, C // H)
|
||||||
|
gu = torch.sum(gu, 0).view(H, C // H)
|
||||||
|
return (None, None, None, None, gr, gk, gv, gw, gu)
|
||||||
|
|
||||||
|
|
||||||
|
def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
|
||||||
|
return WKV_5.apply(B, T, C, H, r, k, v, w, u)
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################
|
||||||
|
class LoraLinear(nn.Module):
|
||||||
|
def __init__(self, in_features: int, out_features: int, bias: bool):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
|
||||||
|
assert bias == False, "Biased LoraLinear not supported"
|
||||||
|
|
||||||
|
r, alpha, dropout = (
|
||||||
|
LORA_CONFIG["r"],
|
||||||
|
LORA_CONFIG["alpha"],
|
||||||
|
LORA_CONFIG["dropout"],
|
||||||
|
)
|
||||||
|
self.lora_A = nn.Parameter(torch.empty(r, in_features))
|
||||||
|
self.lora_B = nn.Parameter(torch.empty(out_features, r))
|
||||||
|
self.lora_dropout = nn.Dropout(dropout)
|
||||||
|
self.scaling = alpha / r
|
||||||
|
|
||||||
|
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||||
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||||
|
nn.init.zeros_(self.lora_B)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return F.linear(x, self.weight) + self.scaling * F.linear(
|
||||||
|
F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.wraps(LoraLinear)
|
||||||
|
def make_linear_att(*args, **kwargs):
|
||||||
|
if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
|
||||||
|
return LoraLinear(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return nn.Linear(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.wraps(LoraLinear)
|
||||||
|
def make_linear_ffn(*args, **kwargs):
|
||||||
|
if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
|
||||||
|
return LoraLinear(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return nn.Linear(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class RWKV_TimeMix_RWKV5(MyModule):
|
||||||
|
def __init__(self, args, layer_id):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.layer_id = layer_id
|
||||||
|
|
||||||
|
self.head_size = args.head_size_a
|
||||||
|
assert HEAD_SIZE == self.head_size # change HEAD_SIZE to match args.head_size_a
|
||||||
|
self.n_head = args.dim_att // self.head_size
|
||||||
|
assert args.dim_att % self.n_head == 0
|
||||||
|
self.head_size_divisor = args.head_size_divisor
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
|
||||||
|
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||||||
|
ddd = torch.ones(1, 1, args.n_embd)
|
||||||
|
for i in range(args.n_embd):
|
||||||
|
ddd[0, 0, i] = i / args.n_embd
|
||||||
|
|
||||||
|
# fancy time_mix
|
||||||
|
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||||
|
self.time_mix_v = nn.Parameter(
|
||||||
|
torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
|
||||||
|
)
|
||||||
|
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
||||||
|
self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
||||||
|
|
||||||
|
# fancy time_decay
|
||||||
|
decay_speed = torch.ones(args.dim_att)
|
||||||
|
for n in range(args.dim_att):
|
||||||
|
decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
|
||||||
|
0.7 + 1.3 * ratio_0_to_1
|
||||||
|
)
|
||||||
|
self.time_decay = nn.Parameter(
|
||||||
|
decay_speed.reshape(self.n_head, self.head_size)
|
||||||
|
)
|
||||||
|
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
|
||||||
|
|
||||||
|
tmp = torch.zeros(args.dim_att)
|
||||||
|
for n in range(args.dim_att):
|
||||||
|
zigzag = ((n + 1) % 3 - 1) * 0.1
|
||||||
|
tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
|
||||||
|
|
||||||
|
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
|
||||||
|
|
||||||
|
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||||
|
|
||||||
|
self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||||
|
self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||||
|
self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||||
|
|
||||||
|
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
|
||||||
|
self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||||
|
self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)
|
||||||
|
|
||||||
|
@MyFunction
|
||||||
|
def jit_func(self, x):
|
||||||
|
B, T, C = x.size()
|
||||||
|
|
||||||
|
xx = self.time_shift(
|
||||||
|
x
|
||||||
|
) # Mix x with the previous timestep to produce xk, xv, xr
|
||||||
|
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||||
|
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||||||
|
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||||
|
xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)
|
||||||
|
|
||||||
|
r = self.receptance(xr)
|
||||||
|
k = self.key(xk)
|
||||||
|
v = self.value(xv)
|
||||||
|
g = F.silu(self.gate(xg))
|
||||||
|
|
||||||
|
return r, k, v, g
|
||||||
|
|
||||||
|
@MyFunction
|
||||||
|
def jit_func_2(self, x, g):
|
||||||
|
B, T, C = x.size()
|
||||||
|
x = x.view(B * T, C)
|
||||||
|
x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
|
||||||
|
x = self.output(x * g)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, T, C = x.size()
|
||||||
|
H = self.n_head
|
||||||
|
r, k, v, g = self.jit_func(x)
|
||||||
|
x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)
|
||||||
|
|
||||||
|
return self.jit_func_2(x, g)
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class RWKV_ChannelMix(MyModule):
|
||||||
|
def __init__(self, args, layer_id):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||||
|
|
||||||
|
with torch.no_grad(): # fancy init of time_mix
|
||||||
|
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||||||
|
ddd = torch.ones(1, 1, args.n_embd)
|
||||||
|
for i in range(args.n_embd):
|
||||||
|
ddd[0, 0, i] = i / args.n_embd
|
||||||
|
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||||
|
self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||||
|
|
||||||
|
self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
|
||||||
|
self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
|
||||||
|
self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)
|
||||||
|
|
||||||
|
@MyFunction
|
||||||
|
def forward(self, x):
|
||||||
|
xx = self.time_shift(x)
|
||||||
|
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||||
|
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||||
|
k = self.key(xk)
|
||||||
|
k = torch.relu(k) ** 2
|
||||||
|
kv = self.value(k)
|
||||||
|
return torch.sigmoid(self.receptance(xr)) * kv
|
||||||
|
|
||||||
|
|
||||||
|
class MishGLU(MyModule):
|
||||||
|
def __init__(self, args, layer_id):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
|
||||||
|
|
||||||
|
x = torch.ones(1, 1, args.n_embd)
|
||||||
|
for i in range(args.n_embd):
|
||||||
|
x[0, 0, i] = i / args.n_embd
|
||||||
|
|
||||||
|
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
||||||
|
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
||||||
|
self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
||||||
|
self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
||||||
|
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
|
||||||
|
|
||||||
|
@MyFunction
|
||||||
|
def forward(self, x):
|
||||||
|
xx = self.time_shift(x)
|
||||||
|
xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||||
|
xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||||
|
a = self.aa(xa)
|
||||||
|
b = self.bb(xb)
|
||||||
|
return self.value(a * F.mish(b))
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
# The RWKV Model with our blocks
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
def __init__(self, args, layer_id):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.layer_id = layer_id
|
||||||
|
|
||||||
|
self.ln1 = nn.LayerNorm(args.n_embd)
|
||||||
|
self.ln2 = nn.LayerNorm(args.n_embd)
|
||||||
|
|
||||||
|
if self.layer_id == 0:
|
||||||
|
self.ln0 = nn.LayerNorm(args.n_embd)
|
||||||
|
if args.my_pos_emb > 0:
|
||||||
|
self.pos_emb_x = nn.Parameter(
|
||||||
|
torch.zeros((1, args.my_pos_emb, args.n_embd))
|
||||||
|
)
|
||||||
|
self.pos_emb_y = nn.Parameter(
|
||||||
|
torch.zeros((args.my_pos_emb, 1, args.n_embd))
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.layer_id == 0 and self.args.pre_ffn > 0:
|
||||||
|
self.ffnPre = RWKV_ChannelMix(args, 0)
|
||||||
|
else:
|
||||||
|
self.att = RWKV_TimeMix_RWKV5(args, layer_id)
|
||||||
|
|
||||||
|
if "g" in os.environ["RWKV_MY_TESTING"]:
|
||||||
|
self.ffn = MishGLU(args, layer_id)
|
||||||
|
else:
|
||||||
|
self.ffn = RWKV_ChannelMix(args, layer_id)
|
||||||
|
|
||||||
|
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
||||||
|
self.tiny_ln = nn.LayerNorm(args.n_embd)
|
||||||
|
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||||
|
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||||
|
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
|
||||||
|
self.register_buffer(
|
||||||
|
"tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.dropout > 0:
|
||||||
|
self.drop0 = nn.Dropout(p=args.dropout)
|
||||||
|
self.drop1 = nn.Dropout(p=args.dropout)
|
||||||
|
|
||||||
|
def forward(self, x, x_emb=None):
|
||||||
|
args = self.args
|
||||||
|
B, T, C = x.size()
|
||||||
|
if self.layer_id == 0:
|
||||||
|
x = self.ln0(x)
|
||||||
|
if args.my_pos_emb > 0:
|
||||||
|
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :]
|
||||||
|
x = x + pos_emb
|
||||||
|
|
||||||
|
if self.args.dropout == 0:
|
||||||
|
if self.layer_id == 0 and args.pre_ffn > 0:
|
||||||
|
x = x + self.ffnPre(self.ln1(x))
|
||||||
|
else:
|
||||||
|
x = x + self.att(self.ln1(x))
|
||||||
|
x = x + self.ffn(self.ln2(x))
|
||||||
|
else:
|
||||||
|
if self.layer_id == 0 and args.pre_ffn > 0:
|
||||||
|
x = self.drop0(x + self.ffnPre(self.ln1(x)))
|
||||||
|
else:
|
||||||
|
x = self.drop0(x + self.att(self.ln1(x)))
|
||||||
|
x = self.drop1(x + self.ffn(self.ln2(x)))
|
||||||
|
|
||||||
|
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
||||||
|
xx = self.tiny_ln(x)
|
||||||
|
q = self.tiny_q(xx)[:, :T, :]
|
||||||
|
k = self.tiny_k(xx)[:, :T, :]
|
||||||
|
c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
|
||||||
|
c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
|
||||||
|
x = x + c @ self.tiny_v(x_emb)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class L2Wrap(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, loss, y):
|
||||||
|
ctx.save_for_backward(y)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
y = ctx.saved_tensors[0]
|
||||||
|
# to encourage the logits to be close to 0
|
||||||
|
factor = 1e-4 / (y.shape[0] * y.shape[1])
|
||||||
|
maxx, ids = torch.max(y, -1, keepdim=True)
|
||||||
|
gy = torch.zeros_like(y)
|
||||||
|
gy.scatter_(-1, ids, maxx * factor)
|
||||||
|
return (grad_output, gy)
|
||||||
|
|
||||||
|
|
||||||
|
class RWKV(pl.LightningModule):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
if not hasattr(args, "dim_att"):
|
||||||
|
args.dim_att = args.n_embd
|
||||||
|
if not hasattr(args, "dim_ffn"):
|
||||||
|
args.dim_ffn = args.n_embd * 4
|
||||||
|
if not hasattr(args, "tiny_att_layer"):
|
||||||
|
args.tiny_att_layer = -1
|
||||||
|
if not hasattr(args, "tiny_att_dim"):
|
||||||
|
args.tiny_att_dim = -1
|
||||||
|
assert args.n_embd % 32 == 0
|
||||||
|
assert args.dim_att % 32 == 0
|
||||||
|
assert args.dim_ffn % 32 == 0
|
||||||
|
|
||||||
|
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
|
||||||
|
|
||||||
|
self.ln_out = nn.LayerNorm(args.n_embd)
|
||||||
|
self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
|
||||||
|
|
||||||
|
if args.head_qk > 0:
|
||||||
|
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||||
|
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||||
|
self.register_buffer(
|
||||||
|
"copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||||||
|
)
|
||||||
|
if args.dropout > 0:
|
||||||
|
self.drop0 = nn.Dropout(p=args.dropout)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
args = self.args
|
||||||
|
|
||||||
|
lr_decay = set()
|
||||||
|
lr_1x = set()
|
||||||
|
lr_2x = set()
|
||||||
|
lr_3x = set()
|
||||||
|
for n, p in self.named_parameters():
|
||||||
|
if ("time_mix" in n) and (args.layerwise_lr > 0):
|
||||||
|
if args.my_pile_stage == 2:
|
||||||
|
lr_2x.add(n)
|
||||||
|
else:
|
||||||
|
lr_1x.add(n)
|
||||||
|
elif ("time_decay" in n) and (args.layerwise_lr > 0):
|
||||||
|
if args.my_pile_stage == 2:
|
||||||
|
lr_3x.add(n)
|
||||||
|
else:
|
||||||
|
lr_2x.add(n)
|
||||||
|
elif ("time_faaaa" in n) and (args.layerwise_lr > 0):
|
||||||
|
if args.my_pile_stage == 2:
|
||||||
|
lr_2x.add(n)
|
||||||
|
else:
|
||||||
|
lr_1x.add(n)
|
||||||
|
elif ("time_first" in n) and (args.layerwise_lr > 0):
|
||||||
|
lr_3x.add(n)
|
||||||
|
elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0):
|
||||||
|
lr_decay.add(n)
|
||||||
|
else:
|
||||||
|
lr_1x.add(n)
|
||||||
|
|
||||||
|
lr_decay = sorted(list(lr_decay))
|
||||||
|
lr_1x = sorted(list(lr_1x))
|
||||||
|
lr_2x = sorted(list(lr_2x))
|
||||||
|
lr_3x = sorted(list(lr_3x))
|
||||||
|
# print('decay', lr_decay)
|
||||||
|
# print('1x', lr_1x)
|
||||||
|
# print('2x', lr_2x)
|
||||||
|
# print('3x', lr_3x)
|
||||||
|
param_dict = {n: p for n, p in self.named_parameters()}
|
||||||
|
|
||||||
|
if args.layerwise_lr > 0:
|
||||||
|
if args.my_pile_stage == 2:
|
||||||
|
optim_groups = [
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_1x],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 1.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_2x],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 5.0,
|
||||||
|
}, # test: 2e-3 / args.lr_init},
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_3x],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 5.0,
|
||||||
|
}, # test: 3e-3 / args.lr_init},
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
optim_groups = [
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_1x],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 1.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_2x],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 2.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_3x],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 3.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
optim_groups = [
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_1x],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"my_lr_scale": 1.0,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
if args.weight_decay > 0:
|
||||||
|
optim_groups += [
|
||||||
|
{
|
||||||
|
"params": [param_dict[n] for n in lr_decay],
|
||||||
|
"weight_decay": args.weight_decay,
|
||||||
|
"my_lr_scale": 1.0,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
if self.deepspeed_offload:
|
||||||
|
return DeepSpeedCPUAdam(
|
||||||
|
optim_groups,
|
||||||
|
lr=self.args.lr_init,
|
||||||
|
betas=self.args.betas,
|
||||||
|
eps=self.args.adam_eps,
|
||||||
|
bias_correction=True,
|
||||||
|
adamw_mode=True,
|
||||||
|
amsgrad=False,
|
||||||
|
)
|
||||||
|
return FusedAdam(
|
||||||
|
optim_groups,
|
||||||
|
lr=self.args.lr_init,
|
||||||
|
betas=self.args.betas,
|
||||||
|
eps=self.args.adam_eps,
|
||||||
|
bias_correction=True,
|
||||||
|
adam_w_mode=True,
|
||||||
|
amsgrad=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.deepspeed_offload:
|
||||||
|
return DeepSpeedCPUAdam(
|
||||||
|
optim_groups,
|
||||||
|
lr=self.args.lr_init,
|
||||||
|
betas=self.args.betas,
|
||||||
|
eps=self.args.adam_eps,
|
||||||
|
bias_correction=True,
|
||||||
|
adamw_mode=False,
|
||||||
|
weight_decay=0,
|
||||||
|
amsgrad=False,
|
||||||
|
)
|
||||||
|
return FusedAdam(
|
||||||
|
optim_groups,
|
||||||
|
lr=self.args.lr_init,
|
||||||
|
betas=self.args.betas,
|
||||||
|
eps=self.args.adam_eps,
|
||||||
|
bias_correction=True,
|
||||||
|
adam_w_mode=False,
|
||||||
|
weight_decay=0,
|
||||||
|
amsgrad=False,
|
||||||
|
)
|
||||||
|
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def deepspeed_offload(self) -> bool:
|
||||||
|
strategy = self.trainer.strategy
|
||||||
|
if isinstance(strategy, DeepSpeedStrategy):
|
||||||
|
cfg = strategy.config["zero_optimization"]
|
||||||
|
return cfg.get("offload_optimizer") or cfg.get("offload_param")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def forward(self, idx):
|
||||||
|
args = self.args
|
||||||
|
B, T = idx.size()
|
||||||
|
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
|
||||||
|
|
||||||
|
x = self.emb(idx)
|
||||||
|
x_emb = x
|
||||||
|
|
||||||
|
if args.dropout > 0:
|
||||||
|
x = self.drop0(x)
|
||||||
|
if args.tiny_att_dim > 0:
|
||||||
|
for block in self.blocks:
|
||||||
|
if args.grad_cp == 1:
|
||||||
|
if args.lora:
|
||||||
|
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
|
||||||
|
else:
|
||||||
|
x = block(x, x_emb)
|
||||||
|
else:
|
||||||
|
for block in self.blocks:
|
||||||
|
if args.grad_cp == 1:
|
||||||
|
if args.lora:
|
||||||
|
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
x = deepspeed.checkpointing.checkpoint(block, x)
|
||||||
|
else:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.ln_out(x)
|
||||||
|
|
||||||
|
if args.head_qk > 0:
|
||||||
|
q = self.head_q(x)[:, :T, :]
|
||||||
|
k = self.head_k(x)[:, :T, :]
|
||||||
|
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
|
||||||
|
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
|
||||||
|
|
||||||
|
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||||
|
c = c @ F.one_hot(idx, num_classes=args.vocab_size)
|
||||||
|
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||||
|
c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
|
||||||
|
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||||
|
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
|
||||||
|
|
||||||
|
x = self.head(x) + c
|
||||||
|
else:
|
||||||
|
x = self.head(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
args = self.args
|
||||||
|
if args.my_qa_mask != 1:
|
||||||
|
idx, targets = batch
|
||||||
|
logits = self(idx)
|
||||||
|
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||||
|
# if '0' in os.environ["RWKV_MY_TESTING"]:
|
||||||
|
# print('logits', logits)
|
||||||
|
# torch.set_printoptions(threshold=10000)
|
||||||
|
# print('idx', idx)
|
||||||
|
# exit(0)
|
||||||
|
else:
|
||||||
|
idx, targets, mask = batch
|
||||||
|
mask = mask.view(-1)
|
||||||
|
sum_mask = torch.sum(mask).item()
|
||||||
|
# if sum_mask == 0:
|
||||||
|
# return torch.tensor([0.0], requires_grad=True)
|
||||||
|
|
||||||
|
logits = self(idx)
|
||||||
|
if sum_mask == mask.shape[0]:
|
||||||
|
loss = F.cross_entropy(
|
||||||
|
logits.view(-1, logits.size(-1)), targets.view(-1)
|
||||||
|
)
|
||||||
|
# print('rank', self.global_rank, 'loss', loss.item())
|
||||||
|
else:
|
||||||
|
loss = F.cross_entropy(
|
||||||
|
logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
|
||||||
|
)
|
||||||
|
# loss_raw = loss
|
||||||
|
loss = torch.sum(loss * mask) / sum_mask
|
||||||
|
|
||||||
|
# torch.set_printoptions(threshold=10000)
|
||||||
|
# if True: #self.global_rank == 1:
|
||||||
|
# tmp = ''
|
||||||
|
# sss = 0
|
||||||
|
# ccc = 0
|
||||||
|
# for i in range(mask.shape[0]):
|
||||||
|
# if mask[i] > 0:
|
||||||
|
# tmp += str(idx.view(-1)[i].item()) + ','
|
||||||
|
# sss += loss_raw.view(-1)[i].float().item()
|
||||||
|
# ccc += 1
|
||||||
|
# print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
|
||||||
|
return L2Wrap.apply(loss, logits)
|
||||||
|
|
||||||
|
def training_step_end(self, batch_parts):
|
||||||
|
if pl.__version__[0] != "2":
|
||||||
|
all = self.all_gather(batch_parts)
|
||||||
|
if self.trainer.is_global_zero:
|
||||||
|
self.trainer.my_loss_all = all
|
||||||
|
|
||||||
|
def generate_init_weight(self):
|
||||||
|
print(
|
||||||
|
f"""
|
||||||
|
############################################################################
|
||||||
|
#
|
||||||
|
# Init model weight (slow for large models)...
|
||||||
|
#
|
||||||
|
############################################################################
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
m = {}
|
||||||
|
for n in self.state_dict():
|
||||||
|
p = self.state_dict()[n]
|
||||||
|
shape = p.shape
|
||||||
|
|
||||||
|
gain = 1.0
|
||||||
|
scale = 1.0
|
||||||
|
if (
|
||||||
|
"ln_" in n
|
||||||
|
or ".ln" in n
|
||||||
|
or "time_" in n
|
||||||
|
or "_mask" in n
|
||||||
|
or "pos_emb" in n
|
||||||
|
or ".mask." in n
|
||||||
|
):
|
||||||
|
if "ln_x.weight" in n:
|
||||||
|
layer_scale = (1 + int(n.split(".")[1])) / self.args.n_layer
|
||||||
|
m[n] = (p * 0.0) + (layer_scale**0.7)
|
||||||
|
else:
|
||||||
|
m[n] = p
|
||||||
|
else:
|
||||||
|
if n == "emb.weight":
|
||||||
|
scale = -1 * self.args.lr_init
|
||||||
|
else:
|
||||||
|
if shape[0] > shape[1]:
|
||||||
|
gain = math.sqrt(shape[0] / shape[1])
|
||||||
|
|
||||||
|
zero = [
|
||||||
|
".att.output.",
|
||||||
|
".ffn.value.",
|
||||||
|
".ffn.receptance.",
|
||||||
|
".ffnPre.value.",
|
||||||
|
".ffnPre.receptance.",
|
||||||
|
"head_q.",
|
||||||
|
".oo.",
|
||||||
|
".rr.",
|
||||||
|
]
|
||||||
|
|
||||||
|
for kk in zero:
|
||||||
|
if kk in n:
|
||||||
|
scale = 0
|
||||||
|
if n == "head.weight":
|
||||||
|
scale = 0.5
|
||||||
|
if "head_k." in n:
|
||||||
|
scale = 0.1
|
||||||
|
if "head_q." in n:
|
||||||
|
scale = 0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.accelerator.upper() == "GPU":
|
||||||
|
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
|
||||||
|
else:
|
||||||
|
m[n] = torch.empty((shape[0], shape[1]))
|
||||||
|
|
||||||
|
if scale == 0:
|
||||||
|
nn.init.zeros_(m[n])
|
||||||
|
elif scale < 0:
|
||||||
|
nn.init.uniform_(m[n], a=scale, b=-scale)
|
||||||
|
else:
|
||||||
|
nn.init.orthogonal_(m[n], gain=gain * scale)
|
||||||
|
|
||||||
|
m[n] = m[n].cpu()
|
||||||
|
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||||
|
m[n] = m[n].half()
|
||||||
|
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||||
|
m[n] = m[n].bfloat16()
|
||||||
|
|
||||||
|
# if n == "emb.weight":
|
||||||
|
# print(m[n])
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return m
|
310
finetune/lora/v5/src/trainer.py
vendored
Normal file
310
finetune/lora/v5/src/trainer.py
vendored
Normal file
@ -0,0 +1,310 @@
|
|||||||
|
import os, math, time, datetime, subprocess
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||||
|
from .model import LORA_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
def my_save(args, trainer, dd, ff):
|
||||||
|
if "14b-run1" in ff:
|
||||||
|
fn = ff.split("/")[-1]
|
||||||
|
fff = "/dev/shm/" + fn
|
||||||
|
torch.save(dd, fff)
|
||||||
|
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
|
||||||
|
elif ("world/14b" in ff) or ("world/7b" in ff):
|
||||||
|
aa = ff.split("/")[1]
|
||||||
|
fn = ff.split("/")[-1]
|
||||||
|
fff = f"/dev/shm/{aa}-{fn}"
|
||||||
|
torch.save(dd, fff)
|
||||||
|
subprocess.Popen(
|
||||||
|
f" aws s3 mv {fff} s3://rwkv-world/{aa}-{fn} --quiet", shell=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if "deepspeed_stage_3" in args.strategy:
|
||||||
|
trainer.save_checkpoint(ff, weights_only=True)
|
||||||
|
else:
|
||||||
|
torch.save(dd, ff)
|
||||||
|
|
||||||
|
|
||||||
|
class train_callback(pl.Callback):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
||||||
|
args = self.args
|
||||||
|
# if args.cuda_cleanup > 0:
|
||||||
|
# torch.cuda.empty_cache()
|
||||||
|
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
||||||
|
|
||||||
|
# LR schedule
|
||||||
|
w_step = args.warmup_steps
|
||||||
|
if args.lr_final == args.lr_init or args.epoch_count == 0:
|
||||||
|
lr = args.lr_init
|
||||||
|
else:
|
||||||
|
decay_step = real_step - args.my_pile_edecay * args.epoch_steps
|
||||||
|
decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
|
||||||
|
progress = (decay_step - w_step + 1) / (decay_total - w_step)
|
||||||
|
progress = min(1, max(0, progress))
|
||||||
|
|
||||||
|
if args.lr_final == 0 or args.lr_init == 0: # linear decay
|
||||||
|
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
|
||||||
|
else: # exp decay
|
||||||
|
lr = args.lr_init * math.exp(
|
||||||
|
math.log(args.lr_final / args.lr_init) * pow(progress, 1)
|
||||||
|
)
|
||||||
|
# if trainer.is_global_zero:
|
||||||
|
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
|
||||||
|
|
||||||
|
if args.my_exit_tokens != 0: # cosine decay
|
||||||
|
real_tokens = real_step * args.ctx_len * args.real_bsz
|
||||||
|
warmup_tokens = w_step * args.ctx_len * args.real_bsz
|
||||||
|
progress = (real_tokens - warmup_tokens) / (
|
||||||
|
abs(args.my_exit_tokens) - warmup_tokens
|
||||||
|
)
|
||||||
|
progress = max(0, min(1, progress))
|
||||||
|
lr_final_factor = args.lr_final / args.lr_init
|
||||||
|
lr_mult = (0.5 + lr_final_factor / 2) + (
|
||||||
|
0.5 - lr_final_factor / 2
|
||||||
|
) * math.cos(math.pi * progress)
|
||||||
|
if args.my_exit_tokens > 0:
|
||||||
|
lr = args.lr_init * lr_mult
|
||||||
|
else:
|
||||||
|
lr = (lr + args.lr_init * lr_mult) / 2
|
||||||
|
if progress >= 1:
|
||||||
|
if (trainer.is_global_zero) or ("deepspeed_stage_3" in args.strategy):
|
||||||
|
my_save(
|
||||||
|
args,
|
||||||
|
trainer,
|
||||||
|
pl_module.state_dict(),
|
||||||
|
f"{args.proj_dir}/rwkv-final.pth",
|
||||||
|
)
|
||||||
|
exit(0)
|
||||||
|
if trainer.global_step < w_step:
|
||||||
|
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
|
||||||
|
|
||||||
|
if args.weight_decay_final > 0:
|
||||||
|
wd_now = args.weight_decay * math.exp(
|
||||||
|
math.log(args.weight_decay_final / args.weight_decay) * progress
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
wd_now = args.weight_decay
|
||||||
|
|
||||||
|
for param_group in trainer.optimizers[0].param_groups:
|
||||||
|
if param_group["weight_decay"] > 0:
|
||||||
|
param_group["weight_decay"] = wd_now
|
||||||
|
if args.layerwise_lr > 0:
|
||||||
|
param_group["lr"] = lr * param_group["my_lr_scale"]
|
||||||
|
# print(param_group["lr"], param_group["my_lr_scale"])
|
||||||
|
else:
|
||||||
|
param_group["lr"] = lr
|
||||||
|
|
||||||
|
trainer.my_lr = lr
|
||||||
|
trainer.my_wd = wd_now
|
||||||
|
# rank_zero_info(f"{real_step} {lr}")
|
||||||
|
|
||||||
|
if trainer.global_step == 0:
|
||||||
|
if trainer.is_global_zero: # logging
|
||||||
|
trainer.my_loss_sum = 0
|
||||||
|
trainer.my_loss_count = 0
|
||||||
|
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
|
||||||
|
trainer.my_log.write(
|
||||||
|
f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
print(f"\n{trainer.strategy.config}\n")
|
||||||
|
trainer.my_log.write(f"{trainer.strategy.config}\n")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
trainer.my_log.flush()
|
||||||
|
if len(args.wandb) > 0:
|
||||||
|
print("Login to wandb...")
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
wandb.init(
|
||||||
|
project=args.wandb,
|
||||||
|
name=args.run_name + " " + args.my_timestamp,
|
||||||
|
config=args,
|
||||||
|
save_code=False,
|
||||||
|
)
|
||||||
|
trainer.my_wandb = wandb
|
||||||
|
|
||||||
|
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||||
|
args = self.args
|
||||||
|
token_per_step = args.ctx_len * args.real_bsz
|
||||||
|
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
||||||
|
if trainer.is_global_zero: # logging
|
||||||
|
t_now = time.time_ns()
|
||||||
|
kt_s = 0
|
||||||
|
try:
|
||||||
|
t_cost = (t_now - trainer.my_time_ns) / 1e9
|
||||||
|
kt_s = token_per_step / t_cost / 1000
|
||||||
|
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
|
||||||
|
self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
trainer.my_time_ns = t_now
|
||||||
|
if pl.__version__[0] == "2":
|
||||||
|
trainer.my_loss = outputs["loss"]
|
||||||
|
else:
|
||||||
|
trainer.my_loss = trainer.my_loss_all.float().mean().item()
|
||||||
|
trainer.my_loss_sum += trainer.my_loss
|
||||||
|
trainer.my_loss_count += 1
|
||||||
|
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
|
||||||
|
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
|
||||||
|
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
|
||||||
|
# self.log("s", real_step, prog_bar=True, on_step=True)
|
||||||
|
|
||||||
|
if len(args.wandb) > 0:
|
||||||
|
lll = {
|
||||||
|
"loss": trainer.my_loss,
|
||||||
|
"lr": trainer.my_lr,
|
||||||
|
"wd": trainer.my_wd,
|
||||||
|
"Gtokens": real_step * token_per_step / 1e9,
|
||||||
|
}
|
||||||
|
if kt_s > 0:
|
||||||
|
lll["kt/s"] = kt_s
|
||||||
|
trainer.my_wandb.log(lll, step=int(real_step))
|
||||||
|
if (trainer.is_global_zero) or (
|
||||||
|
"deepspeed_stage_3" in args.strategy
|
||||||
|
): # save pth
|
||||||
|
if args.magic_prime > 0:
|
||||||
|
expand_factor = 2 if args.my_qa_mask > 0 else 1
|
||||||
|
if int(real_step) == int(
|
||||||
|
args.magic_prime * expand_factor // args.real_bsz
|
||||||
|
) - 1 + int(args.my_random_steps):
|
||||||
|
to_save_dict = pl_module.state_dict()
|
||||||
|
my_save(
|
||||||
|
args,
|
||||||
|
trainer,
|
||||||
|
to_save_dict,
|
||||||
|
f"{args.proj_dir}/rwkv-final.pth",
|
||||||
|
)
|
||||||
|
# if args.batch_save==batch_idx :
|
||||||
|
# to_save_dict = pl_module.state_dict()
|
||||||
|
# for name, state in to_save_dict.items():
|
||||||
|
# if 'img' in name:
|
||||||
|
# to_save_dict[name] = state
|
||||||
|
# try:
|
||||||
|
# my_save(
|
||||||
|
# args, trainer,
|
||||||
|
# to_save_dict,
|
||||||
|
# f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}-{batch_idx}.pth",
|
||||||
|
# )
|
||||||
|
# except Exception as e:
|
||||||
|
# print('Error\n\n', e, '\n\n')
|
||||||
|
|
||||||
|
def on_train_epoch_start(self, trainer, pl_module):
|
||||||
|
args = self.args
|
||||||
|
if pl.__version__[0] == "2":
|
||||||
|
dataset = trainer.train_dataloader.dataset
|
||||||
|
else:
|
||||||
|
dataset = trainer.train_dataloader.dataset.datasets
|
||||||
|
assert "MyDataset" in str(dataset)
|
||||||
|
dataset.global_rank = trainer.global_rank
|
||||||
|
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
|
||||||
|
dataset.world_size = trainer.world_size
|
||||||
|
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
|
||||||
|
|
||||||
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
|
args = self.args
|
||||||
|
to_save_dict = {}
|
||||||
|
if (trainer.is_global_zero) or (
|
||||||
|
"deepspeed_stage_3" in args.strategy
|
||||||
|
): # save pth
|
||||||
|
if (
|
||||||
|
args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0
|
||||||
|
) or (trainer.current_epoch == args.epoch_count - 1):
|
||||||
|
if args.data_type == "wds_img":
|
||||||
|
raw_dict = pl_module.state_dict()
|
||||||
|
for k in raw_dict:
|
||||||
|
if k.startswith("encoder.") or k.startswith("decoder."):
|
||||||
|
to_save_dict[k] = raw_dict[k]
|
||||||
|
else:
|
||||||
|
to_save_dict = pl_module.state_dict()
|
||||||
|
|
||||||
|
if args.data_type == "img" and not args.lora:
|
||||||
|
for name, state in to_save_dict.items():
|
||||||
|
if "img" in name:
|
||||||
|
to_save_dict[name] = state
|
||||||
|
|
||||||
|
if args.lora:
|
||||||
|
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
||||||
|
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
||||||
|
lora_dict = {}
|
||||||
|
for name, state in to_save_dict.items():
|
||||||
|
if "img" in name:
|
||||||
|
lora_dict[name] = state
|
||||||
|
if (
|
||||||
|
".lora_" in name
|
||||||
|
or (enable_time_finetune and ".time_" in name)
|
||||||
|
or (enable_ln_finetune and ".ln" in name)
|
||||||
|
):
|
||||||
|
lora_dict[name] = state
|
||||||
|
to_save_dict = lora_dict
|
||||||
|
|
||||||
|
try:
|
||||||
|
my_save(
|
||||||
|
args,
|
||||||
|
trainer,
|
||||||
|
to_save_dict,
|
||||||
|
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error\n\n", e, "\n\n")
|
||||||
|
|
||||||
|
if trainer.is_global_zero: # logging
|
||||||
|
trainer.my_log.write(
|
||||||
|
f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n"
|
||||||
|
)
|
||||||
|
trainer.my_log.flush()
|
||||||
|
|
||||||
|
trainer.my_loss_sum = 0
|
||||||
|
trainer.my_loss_count = 0
|
||||||
|
if (args.epoch_begin + trainer.current_epoch) >= args.my_exit:
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def generate_init_weight(model, init_weight_name):
|
||||||
|
mm = model.generate_init_weight()
|
||||||
|
|
||||||
|
if model.args.my_pile_stage == 1:
|
||||||
|
if len(model.args.load_model) > 0:
|
||||||
|
print(f"Combine weights from {model.args.load_model}...")
|
||||||
|
load_dict = torch.load(model.args.load_model, map_location="cpu")
|
||||||
|
for k in load_dict:
|
||||||
|
try:
|
||||||
|
assert k in mm
|
||||||
|
except:
|
||||||
|
print("missing", k)
|
||||||
|
exit(0)
|
||||||
|
src = load_dict[k]
|
||||||
|
try:
|
||||||
|
mm[k] = src.reshape(mm[k].shape)
|
||||||
|
except:
|
||||||
|
tmp = mm[k].squeeze().clone()
|
||||||
|
print(k, src.shape, "-->", mm[k].shape)
|
||||||
|
ss = src.shape[0]
|
||||||
|
dd = tmp.shape[0]
|
||||||
|
for i in range(dd):
|
||||||
|
pos = i / dd * ss
|
||||||
|
if pos >= ss - 1:
|
||||||
|
tmp[i] = src[ss - 1]
|
||||||
|
else:
|
||||||
|
p0 = int(math.floor(pos))
|
||||||
|
ii = pos - p0
|
||||||
|
tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii)
|
||||||
|
mm[k] = tmp.reshape(mm[k].shape)
|
||||||
|
sss = src.squeeze().float().cpu().numpy()
|
||||||
|
print(sss[:10], "...", sss[-10:])
|
||||||
|
mmm = mm[k].squeeze().float().cpu().numpy()
|
||||||
|
print(mmm[:10], "...", mmm[-10:])
|
||||||
|
|
||||||
|
print(f"Save to {init_weight_name}...")
|
||||||
|
torch.save(mm, init_weight_name)
|
||||||
|
|
||||||
|
if model.args.my_pile_stage == 1:
|
||||||
|
print("Done. Now go for stage 2.")
|
||||||
|
exit(0)
|
139
finetune/lora/v5/src/utils.py
vendored
Normal file
139
finetune/lora/v5/src/utils.py
vendored
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
import json, time, random, os
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
time_slot = {}
|
||||||
|
time_ref = time.time_ns()
|
||||||
|
|
||||||
|
|
||||||
|
def record_time(name):
|
||||||
|
if name not in time_slot:
|
||||||
|
time_slot[name] = 1e20
|
||||||
|
tt = (time.time_ns() - time_ref) / 1e9
|
||||||
|
if tt < time_slot[name]:
|
||||||
|
time_slot[name] = tt
|
||||||
|
|
||||||
|
|
||||||
|
class TOKENIZER:
|
||||||
|
def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"):
|
||||||
|
if "list" in str(type(WORD_NAME)):
|
||||||
|
self.charMode = False
|
||||||
|
if WORD_NAME[0] == WORD_NAME[1]:
|
||||||
|
from transformers import PreTrainedTokenizerFast
|
||||||
|
|
||||||
|
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
|
||||||
|
else:
|
||||||
|
from transformers import GPT2TokenizerFast
|
||||||
|
|
||||||
|
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
|
||||||
|
self.vocab_size = len(self.tokenizer)
|
||||||
|
else:
|
||||||
|
self.charMode = True
|
||||||
|
with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file:
|
||||||
|
self.word_table = json.load(result_file)
|
||||||
|
|
||||||
|
self.vocab_size = len(self.word_table)
|
||||||
|
|
||||||
|
self.stoi = {v: int(k) for k, v in self.word_table.items()}
|
||||||
|
self.itos = {int(k): v for k, v in self.word_table.items()}
|
||||||
|
|
||||||
|
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
|
||||||
|
|
||||||
|
def refine_context(self, context):
|
||||||
|
context = context.strip().split("\n")
|
||||||
|
for c in range(len(context)):
|
||||||
|
context[c] = context[c].strip().strip("\u3000").strip("\r")
|
||||||
|
context = list(filter(lambda c: c != "", context))
|
||||||
|
context = "\n" + ("\n".join(context)).strip()
|
||||||
|
if context == "":
|
||||||
|
context = "\n"
|
||||||
|
return context
|
||||||
|
|
||||||
|
def sample_logits(
|
||||||
|
self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None
|
||||||
|
):
|
||||||
|
# out[self.UNKNOWN_CHAR] = -float('Inf')
|
||||||
|
lastChar = int(x[-1])
|
||||||
|
|
||||||
|
probs = F.softmax(out, dim=-1)
|
||||||
|
|
||||||
|
if self.charMode:
|
||||||
|
if self.itos[lastChar] == "\n":
|
||||||
|
top_p = top_p_newline
|
||||||
|
else:
|
||||||
|
top_p = top_p_usual
|
||||||
|
else:
|
||||||
|
top_p = top_p_usual
|
||||||
|
|
||||||
|
if os.environ["RWKV_RUN_DEVICE"] == "cpu":
|
||||||
|
probs = probs.numpy()
|
||||||
|
sorted_probs = np.sort(probs)[::-1]
|
||||||
|
cumulative_probs = np.cumsum(sorted_probs)
|
||||||
|
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||||
|
probs[probs < cutoff] = 0
|
||||||
|
if temperature != 1.0:
|
||||||
|
probs = probs.pow(1.0 / temperature)
|
||||||
|
probs = probs / np.sum(probs)
|
||||||
|
out = np.random.choice(a=len(probs), p=probs)
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
sorted_probs = torch.sort(probs, descending=True)[0]
|
||||||
|
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
|
||||||
|
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||||
|
probs[probs < cutoff] = 0
|
||||||
|
if temperature != 1.0:
|
||||||
|
probs = probs.pow(1.0 / temperature)
|
||||||
|
out = torch.multinomial(probs, num_samples=1)[0]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def MaybeIsPrime(number):
|
||||||
|
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def FermatPrimalityTest(number):
|
||||||
|
if number > 1:
|
||||||
|
for time in range(3):
|
||||||
|
randomNumber = random.randint(2, number) - 1
|
||||||
|
if pow(randomNumber, number - 1, number) != 1:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def MillerRabinPrimalityTest(number):
|
||||||
|
if number == 2:
|
||||||
|
return True
|
||||||
|
elif number == 1 or number % 2 == 0:
|
||||||
|
return False
|
||||||
|
oddPartOfNumber = number - 1
|
||||||
|
timesTwoDividNumber = 0
|
||||||
|
while oddPartOfNumber % 2 == 0:
|
||||||
|
oddPartOfNumber = oddPartOfNumber // 2
|
||||||
|
timesTwoDividNumber = timesTwoDividNumber + 1
|
||||||
|
|
||||||
|
for time in range(3):
|
||||||
|
while True:
|
||||||
|
randomNumber = random.randint(2, number) - 1
|
||||||
|
if randomNumber != 0 and randomNumber != 1:
|
||||||
|
break
|
||||||
|
|
||||||
|
randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
|
||||||
|
|
||||||
|
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
|
||||||
|
iterationNumber = 1
|
||||||
|
|
||||||
|
while (iterationNumber <= timesTwoDividNumber - 1) and (
|
||||||
|
randomNumberWithPower != number - 1
|
||||||
|
):
|
||||||
|
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
|
||||||
|
iterationNumber = iterationNumber + 1
|
||||||
|
if randomNumberWithPower != (number - 1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
436
finetune/lora/v5/train.py
vendored
Normal file
436
finetune/lora/v5/train.py
vendored
Normal file
@ -0,0 +1,436 @@
|
|||||||
|
########################################################################################################
|
||||||
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from pytorch_lightning import Trainer
|
||||||
|
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
|
||||||
|
rank_zero_info("########## work in progress ##########")
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
|
||||||
|
parser.add_argument(
|
||||||
|
"--wandb", default="", type=str
|
||||||
|
) # wandb project name. if "" then don't use wandb
|
||||||
|
parser.add_argument("--proj_dir", default="out", type=str)
|
||||||
|
parser.add_argument("--random_seed", default="-1", type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--data_file", default="", type=str)
|
||||||
|
parser.add_argument("--data_type", default="utf-8", type=str)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vocab_size", default=0, type=int
|
||||||
|
) # vocab_size = 0 means auto (for char-level LM and .txt data)
|
||||||
|
|
||||||
|
parser.add_argument("--ctx_len", default=1024, type=int)
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch_steps", default=1000, type=int
|
||||||
|
) # a mini "epoch" has [epoch_steps] steps
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch_count", default=500, type=int
|
||||||
|
) # train for this many "epochs". will continue afterwards with lr = lr_final
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch_begin", default=0, type=int
|
||||||
|
) # if you load a model trained for x "epochs", set epoch_begin = x
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch_save", default=5, type=int
|
||||||
|
) # save the model every [epoch_save] "epochs"
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--micro_bsz", default=12, type=int
|
||||||
|
) # micro batch size (batch size per GPU)
|
||||||
|
parser.add_argument("--n_layer", default=6, type=int)
|
||||||
|
parser.add_argument("--n_embd", default=512, type=int)
|
||||||
|
parser.add_argument("--dim_att", default=0, type=int)
|
||||||
|
parser.add_argument("--dim_ffn", default=0, type=int)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pre_ffn", default=0, type=int
|
||||||
|
) # replace first att layer by ffn (sometimes better)
|
||||||
|
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
|
||||||
|
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
|
||||||
|
parser.add_argument(
|
||||||
|
"--tiny_att_layer", default=-999, type=int
|
||||||
|
) # tiny attention @ which layer
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr_init", default=6e-4, type=float
|
||||||
|
) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
|
||||||
|
parser.add_argument("--lr_final", default=1e-5, type=float)
|
||||||
|
parser.add_argument(
|
||||||
|
"--warmup_steps", default=-1, type=int
|
||||||
|
) # try 50 if you load a model
|
||||||
|
parser.add_argument("--beta1", default=0.9, type=float)
|
||||||
|
parser.add_argument(
|
||||||
|
"--beta2", default=0.99, type=float
|
||||||
|
) # use 0.999 when your model is close to convergence
|
||||||
|
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grad_cp", default=0, type=int
|
||||||
|
) # gradient checkpt: saves VRAM, but slower
|
||||||
|
parser.add_argument(
|
||||||
|
"--dropout", default=0, type=float
|
||||||
|
) # try 0.01 / 0.02 / 0.05 / 0.1
|
||||||
|
parser.add_argument(
|
||||||
|
"--weight_decay", default=0, type=float
|
||||||
|
) # try 0.1 / 0.01 / 0.001
|
||||||
|
parser.add_argument("--weight_decay_final", default=-1, type=float)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--my_pile_version", default=1, type=int
|
||||||
|
) # my special pile version
|
||||||
|
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
|
||||||
|
parser.add_argument(
|
||||||
|
"--my_pile_shift", default=-1, type=int
|
||||||
|
) # my special pile mode - text shift
|
||||||
|
parser.add_argument("--my_pile_edecay", default=0, type=int)
|
||||||
|
parser.add_argument(
|
||||||
|
"--layerwise_lr", default=1, type=int
|
||||||
|
) # layerwise lr for faster convergence (but slower it/s)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds_bucket_mb", default=200, type=int
|
||||||
|
) # deepspeed bucket size in MB. 200 seems enough
|
||||||
|
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
|
||||||
|
|
||||||
|
parser.add_argument("--my_sample_len", default=0, type=int)
|
||||||
|
parser.add_argument("--my_ffn_shift", default=1, type=int)
|
||||||
|
parser.add_argument("--my_att_shift", default=1, type=int)
|
||||||
|
parser.add_argument(
|
||||||
|
"--head_size_a", default=64, type=int
|
||||||
|
) # can try larger values for larger models
|
||||||
|
parser.add_argument("--head_size_divisor", default=8, type=int)
|
||||||
|
parser.add_argument("--my_pos_emb", default=0, type=int)
|
||||||
|
parser.add_argument("--load_partial", default=0, type=int)
|
||||||
|
parser.add_argument("--magic_prime", default=0, type=int)
|
||||||
|
parser.add_argument("--my_qa_mask", default=0, type=int)
|
||||||
|
parser.add_argument("--my_random_steps", default=0, type=int)
|
||||||
|
parser.add_argument("--my_testing", default="", type=str)
|
||||||
|
parser.add_argument("--my_exit", default=99999999, type=int)
|
||||||
|
parser.add_argument("--my_exit_tokens", default=0, type=int)
|
||||||
|
|
||||||
|
# LORA
|
||||||
|
parser.add_argument("--emb", action="store_true")
|
||||||
|
parser.add_argument("--lora", action="store_true")
|
||||||
|
parser.add_argument("--lora_load", default="", type=str)
|
||||||
|
parser.add_argument("--lora_r", default=8, type=int)
|
||||||
|
parser.add_argument("--lora_alpha", default=32, type=float)
|
||||||
|
parser.add_argument("--lora_dropout", default=0.01, type=float)
|
||||||
|
parser.add_argument("--lora_parts", default="att,ln,time", type=str)
|
||||||
|
|
||||||
|
if pl.__version__[0] == "2":
|
||||||
|
parser.add_argument("--accelerator", default="gpu", type=str)
|
||||||
|
parser.add_argument("--strategy", default="auto", type=str)
|
||||||
|
parser.add_argument("--devices", default=1, type=int)
|
||||||
|
parser.add_argument("--num_nodes", default=1, type=int)
|
||||||
|
parser.add_argument("--precision", default="fp16", type=str)
|
||||||
|
parser.add_argument("--accumulate_grad_batches", default=1, type=int)
|
||||||
|
else:
|
||||||
|
parser = Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
import os, warnings, math, datetime, sys, time
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
if "deepspeed" in args.strategy:
|
||||||
|
import deepspeed
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
|
||||||
|
if args.random_seed >= 0:
|
||||||
|
print(
|
||||||
|
f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n"
|
||||||
|
* 3
|
||||||
|
)
|
||||||
|
seed_everything(args.random_seed)
|
||||||
|
|
||||||
|
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore", ".*Consider increasing the value of the `num_workers` argument*"
|
||||||
|
)
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore", ".*The progress bar already tracks a metric with the*"
|
||||||
|
)
|
||||||
|
# os.environ["WDS_SHOW_SEED"] = "1"
|
||||||
|
|
||||||
|
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
|
||||||
|
args.enable_checkpointing = False
|
||||||
|
args.replace_sampler_ddp = False
|
||||||
|
args.logger = False
|
||||||
|
args.gradient_clip_val = 1.0
|
||||||
|
args.num_sanity_val_steps = 0
|
||||||
|
args.check_val_every_n_epoch = int(1e20)
|
||||||
|
args.log_every_n_steps = int(1e20)
|
||||||
|
args.max_epochs = args.epoch_count # -1 continue forever
|
||||||
|
args.betas = (args.beta1, args.beta2)
|
||||||
|
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
|
||||||
|
os.environ["RWKV_MY_TESTING"] = args.my_testing
|
||||||
|
os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)
|
||||||
|
if args.dim_att <= 0:
|
||||||
|
args.dim_att = args.n_embd
|
||||||
|
if args.dim_ffn <= 0:
|
||||||
|
args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size
|
||||||
|
|
||||||
|
if args.data_type == "wds_img":
|
||||||
|
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
|
||||||
|
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
|
||||||
|
else:
|
||||||
|
args.run_name = (
|
||||||
|
f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
|
||||||
|
)
|
||||||
|
if not os.path.exists(args.proj_dir):
|
||||||
|
os.makedirs(args.proj_dir)
|
||||||
|
|
||||||
|
if args.my_pile_stage > 0:
|
||||||
|
magic_prime_bak = args.magic_prime
|
||||||
|
|
||||||
|
if args.my_pile_shift < 0:
|
||||||
|
args.my_pile_shift = 0
|
||||||
|
|
||||||
|
if magic_prime_bak > 0:
|
||||||
|
args.magic_prime = magic_prime_bak
|
||||||
|
if args.my_qa_mask == 2:
|
||||||
|
args.epoch_count = 2 * args.magic_prime // 40320
|
||||||
|
else:
|
||||||
|
args.epoch_count = args.magic_prime // 40320
|
||||||
|
|
||||||
|
args.epoch_steps = 40320 // args.real_bsz
|
||||||
|
assert args.epoch_steps * args.real_bsz == 40320
|
||||||
|
# if args.my_pile_stage == 2:
|
||||||
|
# assert args.lr_final == args.lr_init
|
||||||
|
if args.my_pile_stage >= 2: # find latest saved model
|
||||||
|
list_p = []
|
||||||
|
for p in os.listdir(args.proj_dir):
|
||||||
|
if p.startswith("rwkv") and p.endswith(".pth"):
|
||||||
|
p = ((p.split("-"))[1].split("."))[0]
|
||||||
|
if p != "final":
|
||||||
|
if p == "init":
|
||||||
|
p = -1
|
||||||
|
else:
|
||||||
|
p = int(p)
|
||||||
|
list_p += [p]
|
||||||
|
list_p.sort()
|
||||||
|
max_p = list_p[-1]
|
||||||
|
if len(list_p) > 1:
|
||||||
|
args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted
|
||||||
|
if max_p == -1:
|
||||||
|
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
||||||
|
else:
|
||||||
|
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
||||||
|
if args.warmup_steps < 0:
|
||||||
|
if args.my_pile_stage == 2:
|
||||||
|
args.warmup_steps = 10
|
||||||
|
else:
|
||||||
|
args.warmup_steps = 30
|
||||||
|
args.epoch_begin = max_p + 1
|
||||||
|
|
||||||
|
samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||||
|
tokens_per_epoch = samples_per_epoch * args.ctx_len
|
||||||
|
try:
|
||||||
|
deepspeed_version = deepspeed.__version__
|
||||||
|
except:
|
||||||
|
deepspeed_version = None
|
||||||
|
pass
|
||||||
|
rank_zero_info(
|
||||||
|
f"""
|
||||||
|
############################################################################
|
||||||
|
#
|
||||||
|
# RWKV-5 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
|
||||||
|
#
|
||||||
|
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
|
||||||
|
#
|
||||||
|
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1}, save every {args.epoch_save} epoch
|
||||||
|
#
|
||||||
|
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
|
||||||
|
#
|
||||||
|
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
|
||||||
|
#
|
||||||
|
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
|
||||||
|
#
|
||||||
|
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
|
||||||
|
# Found deepspeed {deepspeed_version}, recommend 0.7.0 (faster than newer versions)
|
||||||
|
# Found pytorch_lightning {pl.__version__}, recommend 1.9.5
|
||||||
|
#
|
||||||
|
############################################################################
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
rank_zero_info(str(vars(args)) + "\n")
|
||||||
|
|
||||||
|
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "uint16"]
|
||||||
|
|
||||||
|
if args.lr_final == 0 or args.lr_init == 0:
|
||||||
|
rank_zero_info(
|
||||||
|
"\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
|
||||||
|
os.environ["RWKV_FLOAT_MODE"] = args.precision
|
||||||
|
if args.precision == "fp32":
|
||||||
|
for i in range(10):
|
||||||
|
rank_zero_info(
|
||||||
|
"\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n"
|
||||||
|
)
|
||||||
|
if args.precision == "fp16":
|
||||||
|
rank_zero_info(
|
||||||
|
"\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
os.environ["RWKV_JIT_ON"] = "0"
|
||||||
|
if "deepspeed_stage_3" in args.strategy:
|
||||||
|
os.environ["RWKV_JIT_ON"] = "0"
|
||||||
|
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
if args.precision == "fp32":
|
||||||
|
torch.backends.cudnn.allow_tf32 = False
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
else:
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
|
if "32" in args.precision:
|
||||||
|
args.precision = 32
|
||||||
|
elif args.precision == "fp16":
|
||||||
|
args.precision = 16
|
||||||
|
else:
|
||||||
|
args.precision = "bf16"
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
from src.trainer import train_callback, generate_init_weight
|
||||||
|
from src.dataset import MyDataset
|
||||||
|
|
||||||
|
train_data = MyDataset(args)
|
||||||
|
args.vocab_size = train_data.vocab_size
|
||||||
|
|
||||||
|
from src.model import RWKV, LORA_CONFIG, LoraLinear
|
||||||
|
|
||||||
|
if args.lora:
|
||||||
|
assert args.lora_r > 0, "LoRA should have its `r` > 0"
|
||||||
|
LORA_CONFIG["r"] = args.lora_r
|
||||||
|
LORA_CONFIG["alpha"] = args.lora_alpha
|
||||||
|
LORA_CONFIG["dropout"] = args.lora_dropout
|
||||||
|
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
|
||||||
|
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
||||||
|
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
||||||
|
model = RWKV(args)
|
||||||
|
# only train lora parameters
|
||||||
|
if args.lora:
|
||||||
|
model.requires_grad_(False)
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if any(n.startswith("lora_") for n, _ in module.named_parameters()):
|
||||||
|
print(f" LoRA additionally training module {name}")
|
||||||
|
for pname, param in module.named_parameters():
|
||||||
|
param.requires_grad = "lora_" in pname
|
||||||
|
elif enable_ln_finetune and ".ln" in name:
|
||||||
|
print(f" LoRA additionally training module {name}")
|
||||||
|
for param in module.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
elif enable_time_finetune and any(
|
||||||
|
n.startswith("time") for n, _ in module.named_parameters()
|
||||||
|
):
|
||||||
|
for pname, param in module.named_parameters():
|
||||||
|
if pname.startswith("time"):
|
||||||
|
print(f" LoRA additionally training parameter {pname}")
|
||||||
|
param.requires_grad = True
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(args.load_model) == 0 or args.my_pile_stage == 1
|
||||||
|
): # shall we build the initial weights?
|
||||||
|
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
|
||||||
|
generate_init_weight(model, init_weight_name) # save initial weights
|
||||||
|
args.load_model = init_weight_name
|
||||||
|
|
||||||
|
rank_zero_info(f"########## Loading {args.load_model}... ##########")
|
||||||
|
try:
|
||||||
|
load_dict = torch.load(args.load_model, map_location="cpu")
|
||||||
|
load_keys = list(load_dict.keys())
|
||||||
|
for k in load_keys:
|
||||||
|
if k.startswith("_forward_module."):
|
||||||
|
load_dict[k.replace("_forward_module.", "")] = load_dict[k]
|
||||||
|
del load_dict[k]
|
||||||
|
except:
|
||||||
|
rank_zero_info(f"Bad checkpoint {args.load_model}")
|
||||||
|
if args.my_pile_stage >= 2: # try again using another checkpoint
|
||||||
|
max_p = args.my_pile_prev_p
|
||||||
|
if max_p == -1:
|
||||||
|
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
||||||
|
else:
|
||||||
|
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
||||||
|
args.epoch_begin = max_p + 1
|
||||||
|
rank_zero_info(f"Trying {args.load_model}")
|
||||||
|
load_dict = torch.load(args.load_model, map_location="cpu")
|
||||||
|
|
||||||
|
if args.load_partial == 1:
|
||||||
|
load_keys = load_dict.keys()
|
||||||
|
for k in model.state_dict():
|
||||||
|
if k not in load_keys:
|
||||||
|
load_dict[k] = model.state_dict()[k]
|
||||||
|
# model.load_state_dict(load_dict)
|
||||||
|
|
||||||
|
model.load_state_dict(load_dict, strict=(not args.lora))
|
||||||
|
if os.path.isfile(args.lora_load):
|
||||||
|
model.load_state_dict(
|
||||||
|
torch.load(args.lora_load, map_location="cpu"), strict=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if pl.__version__[0] == "2":
|
||||||
|
trainer = Trainer(
|
||||||
|
accelerator=args.accelerator,
|
||||||
|
strategy=args.strategy,
|
||||||
|
devices=args.devices,
|
||||||
|
num_nodes=args.num_nodes,
|
||||||
|
precision=args.precision,
|
||||||
|
logger=args.logger,
|
||||||
|
callbacks=[train_callback(args)],
|
||||||
|
max_epochs=args.max_epochs,
|
||||||
|
check_val_every_n_epoch=args.check_val_every_n_epoch,
|
||||||
|
num_sanity_val_steps=args.num_sanity_val_steps,
|
||||||
|
log_every_n_steps=args.log_every_n_steps,
|
||||||
|
enable_checkpointing=args.enable_checkpointing,
|
||||||
|
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||||
|
gradient_clip_val=args.gradient_clip_val,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
trainer = Trainer.from_argparse_args(
|
||||||
|
args,
|
||||||
|
callbacks=[train_callback(args)],
|
||||||
|
)
|
||||||
|
|
||||||
|
if trainer.global_rank == 0:
|
||||||
|
for n in model.state_dict():
|
||||||
|
shape = model.state_dict()[n].shape
|
||||||
|
shape = [i for i in shape if i != 1]
|
||||||
|
if len(shape) > 1:
|
||||||
|
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
|
||||||
|
else:
|
||||||
|
print(f"{str(shape[0]).ljust(5)} {n}")
|
||||||
|
|
||||||
|
if "deepspeed" in args.strategy:
|
||||||
|
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
|
||||||
|
args.ds_bucket_mb * 1000 * 1000
|
||||||
|
)
|
||||||
|
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
|
||||||
|
args.ds_bucket_mb * 1000 * 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
||||||
|
data_loader = DataLoader(
|
||||||
|
train_data,
|
||||||
|
shuffle=False,
|
||||||
|
pin_memory=True,
|
||||||
|
batch_size=args.micro_bsz,
|
||||||
|
num_workers=1,
|
||||||
|
persistent_workers=False,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.fit(model, data_loader)
|
@ -131,7 +131,7 @@ const showError = (e: any) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const errorsMap = Object.entries({
|
const errorsMap = Object.entries({
|
||||||
'python3 ./finetune/lora/train.py': 'Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.',
|
'python3 ./finetune/lora/v': 'Memory is not enough, try to increase the virtual memory (Swap of WSL) or use a smaller base model.',
|
||||||
'cuda out of memory': 'VRAM is not enough',
|
'cuda out of memory': 'VRAM is not enough',
|
||||||
'valueerror: high <= 0': 'Training data is not enough, reduce context length or add more data for training',
|
'valueerror: high <= 0': 'Training data is not enough, reduce context length or add more data for training',
|
||||||
'+= \'+ptx\'': 'Can not find an Nvidia GPU. Perhaps the gpu driver of windows is too old, or you are using WSL 1 for training, please upgrade to WSL 2. e.g. Run "wsl --set-version Ubuntu-22.04 2"',
|
'+= \'+ptx\'': 'Can not find an Nvidia GPU. Perhaps the gpu driver of windows is too old, or you are using WSL 1 for training, please upgrade to WSL 2. e.g. Run "wsl --set-version Ubuntu-22.04 2"',
|
||||||
@ -299,7 +299,6 @@ const LoraFinetune: FC = observer(() => {
|
|||||||
(loraParams.baseModel ? `--load_model models/${loraParams.baseModel} ` : '') +
|
(loraParams.baseModel ? `--load_model models/${loraParams.baseModel} ` : '') +
|
||||||
(loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') +
|
(loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') +
|
||||||
`--data_file ${convertedDataPath} ` +
|
`--data_file ${convertedDataPath} ` +
|
||||||
`--vocab_size ${loraParams.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` +
|
|
||||||
`--ctx_len ${ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
|
`--ctx_len ${ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
|
||||||
`--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` +
|
`--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` +
|
||||||
`--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` +
|
`--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` +
|
||||||
|
Loading…
Reference in New Issue
Block a user