lora finetune (need to be refactored)
This commit is contained in:
7
finetune/data/sample.jsonl
Normal file
7
finetune/data/sample.jsonl
Normal file
@@ -0,0 +1,7 @@
|
||||
{"text": "1:This is the first document."}
|
||||
{"text": "2:Hello\nWorld"}
|
||||
{"text": "3:1+1=2\n1+2=3\n2+2=4"}
|
||||
{"text": "4:You will be training the GPT version because it's paralleziable and faster to train."}
|
||||
{"text": "5:Read the inference code in src/model.py and try using the final hidden state(.xx .aa .bb)"}
|
||||
{"text": "6:You can fine-tune the model with longer ctxLen and it can quickly adapt to longer ctxLens."}
|
||||
{"text": "7:Consider RWKV 14B. The state has 200 vectors, that is, 5 vectors for each block: fp16 (xx), fp32 (aa), fp32 (bb), fp32 (pp), fp16 (xx)."}
|
||||
41
finetune/get_layer_and_embd.py
Normal file
41
finetune/get_layer_and_embd.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
import threading
|
||||
import gc
|
||||
|
||||
|
||||
def file_cleaner(file):
|
||||
last_pos = 0
|
||||
|
||||
def cleaner():
|
||||
nonlocal last_pos
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
pos = file.tell()
|
||||
if pos > last_pos:
|
||||
os.posix_fadvise(
|
||||
file.fileno(), last_pos, pos - last_pos, os.POSIX_FADV_DONTNEED
|
||||
)
|
||||
last_pos = pos
|
||||
|
||||
return cleaner
|
||||
|
||||
|
||||
model_file = open(sys.argv[1], "rb")
|
||||
cleaner = file_cleaner(model_file)
|
||||
cleaner_thread = threading.Thread(target=cleaner, daemon=True)
|
||||
cleaner_thread.start()
|
||||
|
||||
w = torch.load(model_file, map_location="cpu")
|
||||
gc.collect()
|
||||
|
||||
n_embd = w["emb.weight"].shape[1]
|
||||
n_layer = 0
|
||||
keys = list(w.keys())
|
||||
for x in keys:
|
||||
layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
|
||||
n_layer = max(n_layer, layer_id + 1)
|
||||
|
||||
print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
|
||||
46
finetune/install-wsl-dep-and-train.sh
Normal file
46
finetune/install-wsl-dep-and-train.sh
Normal file
@@ -0,0 +1,46 @@
|
||||
if [[ ${cnMirror} == 1 ]]; then
|
||||
export PIP_INDEX_URL="https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
if grep -q "mirrors.aliyun.com" /etc/apt/sources.list; then
|
||||
echo "apt cnMirror already set"
|
||||
else
|
||||
sudo sed -i 's/http:\/\/archive.ubuntu.com\/ubuntu\//http:\/\/mirrors.aliyun.com\/ubuntu\//g' /etc/apt/sources.list
|
||||
sudo apt update
|
||||
fi
|
||||
fi
|
||||
|
||||
if dpkg -s "python3-pip" >/dev/null 2>&1; then
|
||||
echo "pip installed"
|
||||
else
|
||||
sudo apt install python3-pip
|
||||
fi
|
||||
|
||||
if dpkg -s "ninja-build" >/dev/null 2>&1; then
|
||||
echo "ninja installed"
|
||||
else
|
||||
sudo apt install ninja-build
|
||||
fi
|
||||
|
||||
if dpkg -s "cuda" >/dev/null 2>&1; then
|
||||
echo "cuda installed"
|
||||
else
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin
|
||||
sudo mv cuda-wsl-ubuntu.pin /etc/apt/preferences.d/cuda-repository-pin-600
|
||||
wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb
|
||||
sudo dpkg -i cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb
|
||||
sudo cp /var/cuda-repo-wsl-ubuntu-11-7-local/cuda-*-keyring.gpg /usr/share/keyrings/
|
||||
sudo apt-get update
|
||||
sudo apt-get -y install cuda
|
||||
fi
|
||||
|
||||
if python3 -c "import pkg_resources; pkg_resources.require(open('./finetune/requirements.txt',mode='r'))" &>/dev/null; then
|
||||
echo "requirements satisfied"
|
||||
else
|
||||
python3 -m pip install -r ./finetune/requirements.txt
|
||||
fi
|
||||
|
||||
echo "loading $loadModel"
|
||||
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel)
|
||||
echo $modelInfo
|
||||
|
||||
python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
|
||||
--lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu
|
||||
597
finetune/json2binidx_tool/tools/indexed_dataset.py
vendored
Normal file
597
finetune/json2binidx_tool/tools/indexed_dataset.py
vendored
Normal file
@@ -0,0 +1,597 @@
|
||||
# Copyright (c) 2021, EleutherAI
|
||||
# This file is based on code by the authors denoted below and has been modified from its original version.
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# copied from fairseq/fairseq/data/indexed_dataset.py
|
||||
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
|
||||
# other slight modifications to remove fairseq dependencies
|
||||
# Added document index to index file and made it accessible.
|
||||
# An empty sentence no longer separates documents.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import struct
|
||||
from functools import lru_cache
|
||||
from itertools import accumulate
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
|
||||
def __best_fitting_dtype(vocab_size=None):
|
||||
if vocab_size is not None and vocab_size < 65500:
|
||||
return np.uint16
|
||||
else:
|
||||
return np.int32
|
||||
|
||||
|
||||
def infer_dataset_impl(path):
|
||||
if IndexedDataset.exists(path):
|
||||
with open(index_file_path(path), "rb") as f:
|
||||
magic = f.read(8)
|
||||
if magic == IndexedDataset._HDR_MAGIC:
|
||||
return "cached"
|
||||
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
|
||||
return "mmap"
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
print(f"Dataset does not exist: {path}")
|
||||
print(
|
||||
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def make_builder(out_file, impl, vocab_size=None):
|
||||
if impl == "mmap":
|
||||
return MMapIndexedDatasetBuilder(
|
||||
out_file, dtype=__best_fitting_dtype(vocab_size)
|
||||
)
|
||||
else:
|
||||
return IndexedDatasetBuilder(out_file)
|
||||
|
||||
|
||||
def make_dataset(path, impl, skip_warmup=False):
|
||||
if not IndexedDataset.exists(path):
|
||||
print(f"Dataset does not exist: {path}")
|
||||
print(
|
||||
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
|
||||
)
|
||||
return None
|
||||
if impl == "infer":
|
||||
impl = infer_dataset_impl(path)
|
||||
if impl == "lazy" and IndexedDataset.exists(path):
|
||||
return IndexedDataset(path)
|
||||
elif impl == "cached" and IndexedDataset.exists(path):
|
||||
return IndexedCachedDataset(path)
|
||||
elif impl == "mmap" and MMapIndexedDataset.exists(path):
|
||||
return MMapIndexedDataset(path, skip_warmup)
|
||||
print(f"Unknown dataset implementation: {impl}")
|
||||
return None
|
||||
|
||||
|
||||
def dataset_exists(path, impl):
|
||||
if impl == "mmap":
|
||||
return MMapIndexedDataset.exists(path)
|
||||
else:
|
||||
return IndexedDataset.exists(path)
|
||||
|
||||
|
||||
def read_longs(f, n):
|
||||
a = np.empty(n, dtype=np.int64)
|
||||
f.readinto(a)
|
||||
return a
|
||||
|
||||
|
||||
def write_longs(f, a):
|
||||
f.write(np.array(a, dtype=np.int64))
|
||||
|
||||
|
||||
dtypes = {
|
||||
1: np.uint8,
|
||||
2: np.int8,
|
||||
3: np.int16,
|
||||
4: np.int32,
|
||||
5: np.int64,
|
||||
6: np.float32,
|
||||
7: np.float64,
|
||||
8: np.uint16,
|
||||
}
|
||||
|
||||
|
||||
def code(dtype):
|
||||
for k in dtypes.keys():
|
||||
if dtypes[k] == dtype:
|
||||
return k
|
||||
raise ValueError(dtype)
|
||||
|
||||
|
||||
def index_file_path(prefix_path):
|
||||
return prefix_path + ".idx"
|
||||
|
||||
|
||||
def data_file_path(prefix_path):
|
||||
return prefix_path + ".bin"
|
||||
|
||||
|
||||
def create_doc_idx(sizes):
|
||||
doc_idx = [0]
|
||||
for i, s in enumerate(sizes):
|
||||
if s == 0:
|
||||
doc_idx.append(i + 1)
|
||||
return doc_idx
|
||||
|
||||
|
||||
class IndexedDataset(torch.utils.data.Dataset):
|
||||
"""Loader for IndexedDataset"""
|
||||
|
||||
_HDR_MAGIC = b"TNTIDX\x00\x00"
|
||||
|
||||
def __init__(self, path):
|
||||
super().__init__()
|
||||
self.path = path
|
||||
self.data_file = None
|
||||
self.read_index(path)
|
||||
|
||||
def read_index(self, path):
|
||||
with open(index_file_path(path), "rb") as f:
|
||||
magic = f.read(8)
|
||||
assert magic == self._HDR_MAGIC, (
|
||||
"Index file doesn't match expected format. "
|
||||
"Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
version = f.read(8)
|
||||
assert struct.unpack("<Q", version) == (1,)
|
||||
code, self.element_size = struct.unpack("<QQ", f.read(16))
|
||||
self.dtype = dtypes[code]
|
||||
self._len, self.s = struct.unpack("<QQ", f.read(16))
|
||||
self.doc_count = struct.unpack("<Q", f.read(8))
|
||||
self.dim_offsets = read_longs(f, self._len + 1)
|
||||
self.data_offsets = read_longs(f, self._len + 1)
|
||||
self.sizes = read_longs(f, self.s)
|
||||
self.doc_idx = read_longs(f, self.doc_count)
|
||||
|
||||
def read_data(self, path):
|
||||
self.data_file = open(data_file_path(path), "rb", buffering=0)
|
||||
|
||||
def check_index(self, i):
|
||||
if i < 0 or i >= self._len:
|
||||
raise IndexError("index out of range")
|
||||
|
||||
def __del__(self):
|
||||
if self.data_file:
|
||||
self.data_file.close()
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if not self.data_file:
|
||||
self.read_data(self.path)
|
||||
if isinstance(idx, int):
|
||||
i = idx
|
||||
self.check_index(i)
|
||||
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
|
||||
a = np.empty(tensor_size, dtype=self.dtype)
|
||||
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
return a
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||
sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
|
||||
size = sum(sizes)
|
||||
a = np.empty(size, dtype=self.dtype)
|
||||
self.data_file.seek(self.data_offsets[start] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
offsets = list(accumulate(sizes))
|
||||
sents = np.split(a, offsets[:-1])
|
||||
return sents
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.sizes[index]
|
||||
|
||||
def size(self, index):
|
||||
return self.sizes[index]
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return os.path.exists(index_file_path(path)) and os.path.exists(
|
||||
data_file_path(path)
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False # avoid prefetching to save memory
|
||||
|
||||
|
||||
class IndexedCachedDataset(IndexedDataset):
|
||||
def __init__(self, path):
|
||||
super().__init__(path)
|
||||
self.cache = None
|
||||
self.cache_index = {}
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return True
|
||||
|
||||
def prefetch(self, indices):
|
||||
if all(i in self.cache_index for i in indices):
|
||||
return
|
||||
if not self.data_file:
|
||||
self.read_data(self.path)
|
||||
indices = sorted(set(indices))
|
||||
total_size = 0
|
||||
for i in indices:
|
||||
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
|
||||
self.cache = np.empty(total_size, dtype=self.dtype)
|
||||
ptx = 0
|
||||
self.cache_index.clear()
|
||||
for i in indices:
|
||||
self.cache_index[i] = ptx
|
||||
size = self.data_offsets[i + 1] - self.data_offsets[i]
|
||||
a = self.cache[ptx : ptx + size]
|
||||
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
ptx += size
|
||||
if self.data_file:
|
||||
# close and delete data file after prefetch so we can pickle
|
||||
self.data_file.close()
|
||||
self.data_file = None
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
i = idx
|
||||
self.check_index(i)
|
||||
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
|
||||
a = np.empty(tensor_size, dtype=self.dtype)
|
||||
ptx = self.cache_index[i]
|
||||
np.copyto(a, self.cache[ptx : ptx + a.size])
|
||||
return a
|
||||
elif isinstance(idx, slice):
|
||||
# Hack just to make this work, can optimizer later if necessary
|
||||
sents = []
|
||||
for i in range(*idx.indices(len(self))):
|
||||
sents.append(self[i])
|
||||
return sents
|
||||
|
||||
|
||||
class IndexedDatasetBuilder(object):
|
||||
element_sizes = {
|
||||
np.uint8: 1,
|
||||
np.int8: 1,
|
||||
np.int16: 2,
|
||||
np.int32: 4,
|
||||
np.int64: 8,
|
||||
np.float32: 4,
|
||||
np.float64: 8,
|
||||
}
|
||||
|
||||
def __init__(self, out_file, dtype=np.int32):
|
||||
self.out_file = open(out_file, "wb")
|
||||
self.dtype = dtype
|
||||
self.data_offsets = [0]
|
||||
self.dim_offsets = [0]
|
||||
self.sizes = []
|
||||
self.element_size = self.element_sizes[self.dtype]
|
||||
self.doc_idx = [0]
|
||||
|
||||
def add_item(self, np_array):
|
||||
assert isinstance(np_array, np.ndarray) and np_array.dtype == self.dtype
|
||||
bytes = self.out_file.write(np_array)
|
||||
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
|
||||
for s in np_array.shape:
|
||||
self.sizes.append(s)
|
||||
self.dim_offsets.append(self.dim_offsets[-1] + len(np_array.shape))
|
||||
|
||||
def end_document(self):
|
||||
self.doc_idx.append(len(self.sizes))
|
||||
|
||||
def merge_file_(self, another_file):
|
||||
index = IndexedDataset(another_file)
|
||||
assert index.dtype == self.dtype
|
||||
|
||||
begin = self.data_offsets[-1]
|
||||
for offset in index.data_offsets[1:]:
|
||||
self.data_offsets.append(begin + offset)
|
||||
self.sizes.extend(index.sizes)
|
||||
begin = self.dim_offsets[-1]
|
||||
for dim_offset in index.dim_offsets[1:]:
|
||||
self.dim_offsets.append(begin + dim_offset)
|
||||
|
||||
with open(data_file_path(another_file), "rb") as f:
|
||||
while True:
|
||||
data = f.read(1024)
|
||||
if data:
|
||||
self.out_file.write(data)
|
||||
else:
|
||||
break
|
||||
|
||||
def finalize(self, index_file):
|
||||
self.out_file.close()
|
||||
index = open(index_file, "wb")
|
||||
index.write(b"TNTIDX\x00\x00")
|
||||
index.write(struct.pack("<Q", 1))
|
||||
index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
|
||||
index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
|
||||
index.write(struct.pack("<Q", len(self.doc_idx)))
|
||||
write_longs(index, self.dim_offsets)
|
||||
write_longs(index, self.data_offsets)
|
||||
write_longs(index, self.sizes)
|
||||
write_longs(index, self.doc_idx)
|
||||
index.close()
|
||||
|
||||
|
||||
def _warmup_mmap_file(path):
|
||||
with open(path, "rb") as stream:
|
||||
while stream.read(100 * 1024 * 1024):
|
||||
pass
|
||||
|
||||
|
||||
class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||
class Index(object):
|
||||
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||
|
||||
@classmethod
|
||||
def writer(cls, path, dtype):
|
||||
class _Writer(object):
|
||||
def __enter__(self):
|
||||
self._file = open(path, "wb")
|
||||
|
||||
# Write Magic string so we can check the file format then opening it again.
|
||||
self._file.write(cls._HDR_MAGIC)
|
||||
# Write version number
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", 1))
|
||||
# Little endian unsigned 8 Bit integer
|
||||
self._file.write(struct.pack("<B", code(dtype)))
|
||||
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _get_pointers(sizes):
|
||||
pointers = np.zeros(len(sizes), dtype=np.int64)
|
||||
sizes = np.array(sizes, dtype=np.int64)
|
||||
|
||||
np.cumsum(sizes[:-1], out=pointers[1:])
|
||||
pointers = pointers * dtype().itemsize
|
||||
return pointers
|
||||
|
||||
def write(self, sizes, doc_idx):
|
||||
pointers = self._get_pointers(sizes)
|
||||
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", len(sizes)))
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", len(doc_idx)))
|
||||
|
||||
sizes = np.array(sizes, dtype=np.int32)
|
||||
self._file.write(sizes.tobytes(order="C"))
|
||||
del sizes
|
||||
|
||||
pointers = np.array(pointers, dtype=np.int64)
|
||||
self._file.write(pointers.tobytes(order="C"))
|
||||
del pointers
|
||||
|
||||
doc_idx = np.array(doc_idx, dtype=np.int64)
|
||||
self._file.write(doc_idx.tobytes(order="C"))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._file.close()
|
||||
|
||||
return _Writer()
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
with open(path, "rb") as stream:
|
||||
magic_test = stream.read(9)
|
||||
assert self._HDR_MAGIC == magic_test, (
|
||||
"Index file doesn't match expected format. "
|
||||
"Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
# Little endian unsigned 64 Bit integer
|
||||
version = struct.unpack("<Q", stream.read(8))
|
||||
assert (1,) == version
|
||||
|
||||
# Little endian unsigned 8 Bit integer
|
||||
(dtype_code,) = struct.unpack("<B", stream.read(1))
|
||||
self._dtype = dtypes[dtype_code]
|
||||
self._dtype_size = self._dtype().itemsize
|
||||
|
||||
self._len = struct.unpack("<Q", stream.read(8))[0]
|
||||
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
|
||||
offset = stream.tell()
|
||||
|
||||
if not skip_warmup:
|
||||
print(" warming up index mmap file...")
|
||||
_warmup_mmap_file(path)
|
||||
|
||||
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
print(" reading sizes...")
|
||||
self._sizes = np.frombuffer(
|
||||
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
|
||||
)
|
||||
print(" reading pointers...")
|
||||
self._pointers = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._len,
|
||||
offset=offset + self._sizes.nbytes,
|
||||
)
|
||||
print(" reading document index...")
|
||||
self._doc_idx = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._doc_count,
|
||||
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._doc_idx
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def __getitem__(self, i):
|
||||
return self._pointers[i], self._sizes[i]
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
super().__init__()
|
||||
|
||||
self._path = None
|
||||
self._index = None
|
||||
self._bin_buffer = None
|
||||
|
||||
self._do_init(path, skip_warmup)
|
||||
|
||||
def __getstate__(self):
|
||||
return self._path
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._do_init(state)
|
||||
|
||||
def _do_init(self, path, skip_warmup):
|
||||
self._path = path
|
||||
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
||||
|
||||
if not skip_warmup:
|
||||
print(" warming up data mmap file...")
|
||||
_warmup_mmap_file(data_file_path(self._path))
|
||||
print(" creating numpy buffer of mmap...")
|
||||
self._bin_buffer_mmap = np.memmap(
|
||||
data_file_path(self._path), mode="r", order="C"
|
||||
)
|
||||
print(" creating memory view of numpy buffer...")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
del self._index
|
||||
|
||||
def __len__(self):
|
||||
return len(self._index)
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
ptr, size = self._index[idx]
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||
)
|
||||
return np_array
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||
ptr = self._index._pointers[start]
|
||||
sizes = self._index._sizes[idx]
|
||||
offsets = list(accumulate(sizes))
|
||||
total_size = sum(sizes)
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
|
||||
)
|
||||
sents = np.split(np_array, offsets[:-1])
|
||||
return sents
|
||||
|
||||
def get(self, idx, offset=0, length=None):
|
||||
"""Retrieves a single item from the dataset with the option to only
|
||||
return a portion of the item.
|
||||
|
||||
get(idx) is the same as [idx] but get() does not support slicing.
|
||||
"""
|
||||
ptr, size = self._index[idx]
|
||||
if length is None:
|
||||
length = size - offset
|
||||
ptr += offset * np.dtype(self._index.dtype).itemsize
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
||||
)
|
||||
return np_array
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._index.sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._index.doc_idx
|
||||
|
||||
def get_doc_idx(self):
|
||||
return self._index._doc_idx
|
||||
|
||||
def set_doc_idx(self, doc_idx_):
|
||||
self._index._doc_idx = doc_idx_
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return os.path.exists(index_file_path(path)) and os.path.exists(
|
||||
data_file_path(path)
|
||||
)
|
||||
|
||||
|
||||
class MMapIndexedDatasetBuilder(object):
|
||||
def __init__(self, out_file, dtype=np.int64):
|
||||
self._data_file = open(out_file, "wb")
|
||||
self._dtype = dtype
|
||||
self._sizes = []
|
||||
self._doc_idx = [0]
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
def add_item(self, np_array):
|
||||
assert isinstance(np_array, np.ndarray) and np_array.dtype == self.dtype
|
||||
self._data_file.write(np_array.tobytes(order="C"))
|
||||
self._sizes.append(np_array.size)
|
||||
|
||||
def end_document(self):
|
||||
self._doc_idx.append(len(self._sizes))
|
||||
|
||||
def merge_file_(self, another_file):
|
||||
# Concatenate index
|
||||
index = MMapIndexedDataset.Index(index_file_path(another_file))
|
||||
assert index.dtype == self._dtype
|
||||
|
||||
for size in index.sizes:
|
||||
self._sizes.append(size)
|
||||
|
||||
# Concatenate data
|
||||
with open(data_file_path(another_file), "rb") as f:
|
||||
shutil.copyfileobj(f, self._data_file)
|
||||
|
||||
def finalize(self, index_file):
|
||||
self._data_file.close()
|
||||
|
||||
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
|
||||
index.write(self._sizes, self._doc_idx)
|
||||
243
finetune/json2binidx_tool/tools/preprocess_data.py
vendored
Normal file
243
finetune/json2binidx_tool/tools/preprocess_data.py
vendored
Normal file
@@ -0,0 +1,243 @@
|
||||
# Copyright (c) 2021, EleutherAI
|
||||
# This file is based on code by the authors denoted below and has been modified from its original version.
|
||||
#
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Processing data for pretraining."""
|
||||
|
||||
import argparse
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
|
||||
import lm_dataformat as lmd
|
||||
import numpy as np
|
||||
|
||||
sys.path.append(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
|
||||
)
|
||||
import time
|
||||
import tqdm
|
||||
import ftfy
|
||||
|
||||
from tokenizer import build_tokenizer
|
||||
import indexed_dataset
|
||||
from threading import Semaphore
|
||||
|
||||
|
||||
class Encoder(object):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
def initializer(self):
|
||||
# Use Encoder class as a container for global data
|
||||
Encoder.tokenizer = build_tokenizer(self.args)
|
||||
|
||||
def encode(self, text):
|
||||
if self.args.ftfy:
|
||||
text = ftfy.fix_text(text)
|
||||
ids = {}
|
||||
for key in self.args.jsonl_keys:
|
||||
doc_ids = []
|
||||
text_ids = Encoder.tokenizer.tokenize(text)
|
||||
if len(text_ids) > 0:
|
||||
doc_ids.append(text_ids)
|
||||
if self.args.append_eod:
|
||||
doc_ids[-1].append(Encoder.tokenizer.eod)
|
||||
ids[key] = doc_ids
|
||||
return ids, len(text)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
group = parser.add_argument_group(title="input data")
|
||||
group.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to input jsonl files or lmd archive(s) - if using multiple archives, put them in a comma separated "
|
||||
"list",
|
||||
)
|
||||
group.add_argument(
|
||||
"--jsonl-keys",
|
||||
nargs="+",
|
||||
default=["text"],
|
||||
help="space separate listed of keys to extract from jsonl. Defa",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-docs",
|
||||
default=None,
|
||||
help="Optional: Number of documents in the input data (if known) for an accurate progress bar.",
|
||||
type=int,
|
||||
)
|
||||
group = parser.add_argument_group(title="tokenizer")
|
||||
group.add_argument(
|
||||
"--tokenizer-type",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=[
|
||||
"HFGPT2Tokenizer",
|
||||
"HFTokenizer",
|
||||
"GPT2BPETokenizer",
|
||||
"CharLevelTokenizer",
|
||||
"TiktokenTokenizer",
|
||||
"RWKVTokenizer",
|
||||
],
|
||||
help="What type of tokenizer to use.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--vocab-file", type=str, default=None, help="Path to the vocab file"
|
||||
)
|
||||
group.add_argument(
|
||||
"--merge-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the BPE merge file (if necessary).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--append-eod",
|
||||
action="store_true",
|
||||
help="Append an <eod> token to the end of a document.",
|
||||
)
|
||||
group.add_argument("--ftfy", action="store_true", help="Use ftfy to clean text")
|
||||
group = parser.add_argument_group(title="output data")
|
||||
group.add_argument(
|
||||
"--output-prefix",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to binary output file without suffix",
|
||||
)
|
||||
group.add_argument(
|
||||
"--dataset-impl",
|
||||
type=str,
|
||||
default="mmap",
|
||||
choices=["lazy", "cached", "mmap"],
|
||||
help="Dataset implementation to use. Default: mmap",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group(title="runtime")
|
||||
group.add_argument(
|
||||
"--workers", type=int, default=1, help="Number of worker processes to launch"
|
||||
)
|
||||
group.add_argument(
|
||||
"--log-interval",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Interval between progress updates",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args.keep_empty = False
|
||||
|
||||
# some default/dummy values for the tokenizer
|
||||
args.rank = 0
|
||||
args.make_vocab_size_divisible_by = 128
|
||||
args.model_parallel_size = 1
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def yield_from_files(fnames: list, semaphore):
|
||||
"""
|
||||
Iterator over input documents using lm_dataformat. Should be able to handle jsons / texts /
|
||||
other compressed formats. Also filters out empty documents.
|
||||
|
||||
:param fnames: list of filenames
|
||||
"""
|
||||
|
||||
def yielder(fname, semaphore):
|
||||
for f in filter(lambda x: x, lmd.Reader(fname).stream_data()):
|
||||
semaphore.acquire()
|
||||
yield f
|
||||
|
||||
for fname in fnames:
|
||||
semaphore.acquire()
|
||||
|
||||
yield from yielder(fname, semaphore)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
encoder = Encoder(args)
|
||||
tokenizer = build_tokenizer(args)
|
||||
print(f"Vocab size: {tokenizer.vocab_size}")
|
||||
print(f"Output prefix: {args.output_prefix}")
|
||||
|
||||
# build a semaphore object to stop `yield_from_files` from getting ahead of encoder.encode and
|
||||
# hence building up memory
|
||||
semaphore = Semaphore(10000 + args.workers)
|
||||
|
||||
# use multiprocessing to iterate over input documents
|
||||
fin = yield_from_files(args.input.split(","), semaphore)
|
||||
|
||||
if args.workers > 1:
|
||||
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
|
||||
encoded_docs = pool.imap(encoder.encode, fin, chunksize=25)
|
||||
else:
|
||||
encoder.initializer()
|
||||
encoded_docs = (encoder.encode(doc) for doc in fin)
|
||||
|
||||
# make a dataset builder for each key in args.jsonl_keys
|
||||
# each key will output to a different file beginning with args.output_prefix
|
||||
output_bin_files = {}
|
||||
output_idx_files = {}
|
||||
builders = {}
|
||||
for key in args.jsonl_keys:
|
||||
output_bin_files[key] = "{}_{}_{}.bin".format(
|
||||
args.output_prefix, key, "document"
|
||||
)
|
||||
output_idx_files[key] = "{}_{}_{}.idx".format(
|
||||
args.output_prefix, key, "document"
|
||||
)
|
||||
builders[key] = indexed_dataset.make_builder(
|
||||
output_bin_files[key],
|
||||
impl=args.dataset_impl,
|
||||
vocab_size=tokenizer.vocab_size,
|
||||
)
|
||||
|
||||
# actually do tokenization
|
||||
proc_start = time.time()
|
||||
total_bytes_processed = 0
|
||||
pbar = tqdm.tqdm()
|
||||
for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
|
||||
total_bytes_processed += bytes_processed
|
||||
|
||||
# release semaphore so `yield_from_files` can add another file to the buffer
|
||||
semaphore.release()
|
||||
|
||||
# add each tokenized document / sentence
|
||||
for key, sentences in doc.items():
|
||||
for sentence in sentences:
|
||||
builders[key].add_item(np.array(sentence, dtype=builders[key].dtype))
|
||||
# separate with eos token
|
||||
builders[key].end_document()
|
||||
|
||||
# log progress
|
||||
if i % args.log_interval == 0:
|
||||
current = time.time()
|
||||
elapsed = current - proc_start
|
||||
mbs = total_bytes_processed / elapsed / 1024 / 1024
|
||||
pbar.set_description(
|
||||
f"Processed {i}{'' if args.num_docs is None else '/' + str(args.num_docs)} documents ({i / elapsed:0.2f} docs/s, {mbs:0.2f} MB/s)."
|
||||
)
|
||||
if i != 0:
|
||||
pbar.update(args.log_interval)
|
||||
|
||||
# save output file
|
||||
for key in args.jsonl_keys:
|
||||
builders[key].finalize(output_idx_files[key])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
232
finetune/json2binidx_tool/tools/rwkv_tokenizer.py
vendored
Normal file
232
finetune/json2binidx_tool/tools/rwkv_tokenizer.py
vendored
Normal file
@@ -0,0 +1,232 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
# Source: https://github.com/BlinkDL/ChatRWKV/blob/main/tokenizer/rwkv_tokenizer.py
|
||||
########################################################################################################
|
||||
|
||||
import os, sys, time, random
|
||||
|
||||
print('''
|
||||
#######################################################################################################################
|
||||
|
||||
This tokenizer is not used in any RWKV models yet. I plan to use it for the future multilang RWKV models.
|
||||
|
||||
Benefits:
|
||||
|
||||
* Good support of most languages, from European to CJK to Arabic and Hindi and more.
|
||||
|
||||
* Clean vocab. Good for code too. Vocab size = 65525 (use 0 for <|endoftext|>).
|
||||
|
||||
* Good at numbers: the numerical tokens are '0'~'9', '10'~'99', ' 0'~' 9', ' 10'~' 99'.
|
||||
|
||||
* Very easy tokenization:
|
||||
|
||||
** The input text must be in UTF-8.
|
||||
|
||||
** Greedy encoding: always pick the longest (in bytes) token (with the highest id) that matches your UTF-8 bytes.
|
||||
|
||||
* The tokenization result is surprisingly good, because the vocab respects word boundaries and UTF-8 boundaries.
|
||||
|
||||
For 10x faster speed:
|
||||
mypyc rwkv_tokenizer.py
|
||||
python3 -c "import rwkv_tokenizer"
|
||||
|
||||
#######################################################################################################################
|
||||
''')
|
||||
|
||||
########################################################################################################
|
||||
# Tokenizer #1 (reference, naive, slow)
|
||||
########################################################################################################
|
||||
|
||||
class RWKV_TOKENIZER():
|
||||
table = None # : list[list[list[bytes]]] = None
|
||||
good = None # : list[set[int]]
|
||||
wlen = None # : list[int]
|
||||
def __init__(self, file_name):
|
||||
self.vocab_size = 65525
|
||||
self.idx2token = {}
|
||||
sorted = [] # must be already sorted
|
||||
lines = open(file_name, "r", encoding="utf-8").readlines()
|
||||
for l in lines:
|
||||
idx = int(l[:l.index(' ')])
|
||||
x = eval(l[l.index(' '):l.rindex(' ')])
|
||||
x = x.encode("utf-8") if isinstance(x, str) else x
|
||||
assert isinstance(x, bytes)
|
||||
assert len(x) == int(l[l.rindex(' '):])
|
||||
sorted += [x]
|
||||
self.idx2token[idx] = x
|
||||
|
||||
self.token2idx = {}
|
||||
for k, v in self.idx2token.items():
|
||||
self.token2idx[v] = int(k)
|
||||
|
||||
# precompute some tables for fast matching
|
||||
self.table = [[[] for j in range(256)] for i in range(256)]
|
||||
self.good = [set() for i in range(256)]
|
||||
self.wlen = [0 for i in range(256)]
|
||||
|
||||
for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
|
||||
s = sorted[i]
|
||||
if len(s) >= 2:
|
||||
s0 = int(s[0])
|
||||
s1 = int(s[1])
|
||||
self.table[s0][s1] += [s]
|
||||
self.wlen[s0] = max(self.wlen[s0], len(s))
|
||||
self.good[s0].add(s1)
|
||||
|
||||
def encodeBytes(self, src: bytes):
|
||||
src_len: int = len(src)
|
||||
tokens = []
|
||||
i: int = 0
|
||||
while i < src_len:
|
||||
s: bytes = src[i : i + 1]
|
||||
|
||||
if i < src_len - 1:
|
||||
s1: int = int(src[i + 1])
|
||||
s0: int = int(src[i])
|
||||
if s1 in self.good[s0]:
|
||||
sss: bytes = src[i : i + self.wlen[s0]]
|
||||
try:
|
||||
s = next(filter(sss.startswith, self.table[s0][s1]))
|
||||
except:
|
||||
pass
|
||||
tokens.append(self.token2idx[s])
|
||||
i += len(s)
|
||||
|
||||
return tokens
|
||||
|
||||
def decodeBytes(self, tokens):
|
||||
return b''.join(map(lambda i: self.idx2token[i], tokens))
|
||||
|
||||
def encode(self, src: str):
|
||||
return self.encodeBytes(src.encode("utf-8"))
|
||||
|
||||
def decode(self, tokens):
|
||||
return self.decodeBytes(tokens).decode('utf-8')
|
||||
|
||||
def token_to_id(self, token):
|
||||
return self.token2idx[token]
|
||||
|
||||
def get_vocab_size(self):
|
||||
return self.vocab_size
|
||||
|
||||
def get_vocab(self):
|
||||
return self.idx2token
|
||||
|
||||
def printTokens(self, tokens):
|
||||
for i in tokens:
|
||||
s = self.idx2token[i]
|
||||
try:
|
||||
s = s.decode('utf-8')
|
||||
except:
|
||||
pass
|
||||
print(f'{repr(s)}{i}', end=' ')
|
||||
# print(repr(s), i)
|
||||
print()
|
||||
|
||||
########################################################################################################
|
||||
# Tokenizer #2 (trie, faster) https://github.com/TkskKurumi/ChatRWKV-TRIE-Tokenizer
|
||||
########################################################################################################
|
||||
|
||||
class TRIE:
|
||||
__slots__ = tuple("ch,to,values,front".split(","))
|
||||
to:list
|
||||
values:set
|
||||
def __init__(self, front=None, ch=None):
|
||||
self.ch = ch
|
||||
self.to = [None for ch in range(256)]
|
||||
self.values = set()
|
||||
self.front = front
|
||||
|
||||
def __repr__(self):
|
||||
fr = self
|
||||
ret = []
|
||||
while(fr!=None):
|
||||
if(fr.ch!=None):
|
||||
ret.append(fr.ch)
|
||||
fr = fr.front
|
||||
return "<TRIE %s %s>"%(ret[::-1], self.values)
|
||||
|
||||
def add(self, key:bytes, idx:int=0, val=None):
|
||||
if(idx == len(key)):
|
||||
if(val is None):
|
||||
val = key
|
||||
self.values.add(val)
|
||||
return self
|
||||
ch = key[idx]
|
||||
if(self.to[ch] is None):
|
||||
self.to[ch] = TRIE(front=self, ch=ch)
|
||||
return self.to[ch].add(key, idx=idx+1, val=val)
|
||||
|
||||
def find_longest(self, key:bytes, idx:int=0):
|
||||
u:TRIE = self
|
||||
ch:int = key[idx]
|
||||
|
||||
while(u.to[ch] is not None):
|
||||
u = u.to[ch]
|
||||
idx += 1
|
||||
if(u.values):
|
||||
ret = idx, u, u.values
|
||||
if(idx==len(key)):
|
||||
break
|
||||
ch = key[idx]
|
||||
return ret
|
||||
|
||||
class TRIE_TOKENIZER():
|
||||
def __init__(self, file_name):
|
||||
self.vocab_size = 65525
|
||||
self.idx2token = {}
|
||||
sorted = [] # must be already sorted
|
||||
with open(file_name, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for l in lines:
|
||||
idx = int(l[:l.index(' ')])
|
||||
x = eval(l[l.index(' '):l.rindex(' ')])
|
||||
x = x.encode("utf-8") if isinstance(x, str) else x
|
||||
assert isinstance(x, bytes)
|
||||
assert len(x) == int(l[l.rindex(' '):])
|
||||
sorted += [x]
|
||||
self.idx2token[idx] = x
|
||||
|
||||
self.token2idx = {}
|
||||
for k,v in self.idx2token.items():
|
||||
self.token2idx[v] = int(k)
|
||||
|
||||
self.root = TRIE()
|
||||
for t, i in self.token2idx.items():
|
||||
_ = self.root.add(t, val=(t, i))
|
||||
|
||||
def encodeBytes(self, src:bytes):
|
||||
idx:int = 0
|
||||
tokens = []
|
||||
while (idx < len(src)):
|
||||
_idx:int = idx
|
||||
idx, _, values = self.root.find_longest(src, idx)
|
||||
assert(idx != _idx)
|
||||
_, token = next(iter(values))
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
def decodeBytes(self, tokens):
|
||||
return b''.join(map(lambda i: self.idx2token[i], tokens))
|
||||
|
||||
def encode(self, src):
|
||||
return self.encodeBytes(src.encode("utf-8"))
|
||||
|
||||
def decode(self, tokens):
|
||||
return self.decodeBytes(tokens).decode('utf-8')
|
||||
|
||||
def get_vocab_size(self):
|
||||
return self.vocab_size
|
||||
|
||||
def get_vocab(self):
|
||||
return self.idx2token
|
||||
|
||||
def printTokens(self, tokens):
|
||||
for i in tokens:
|
||||
s = self.idx2token[i]
|
||||
try:
|
||||
s = s.decode('utf-8')
|
||||
except:
|
||||
pass
|
||||
print(f'{repr(s)}{i}', end=' ')
|
||||
print()
|
||||
205
finetune/json2binidx_tool/tools/tokenizer.py
vendored
Normal file
205
finetune/json2binidx_tool/tools/tokenizer.py
vendored
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright (c) 2021, EleutherAI
|
||||
# This file is based on code by the authors denoted below and has been modified from its original version.
|
||||
#
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Megatron tokenizers."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
from tokenizers import Tokenizer
|
||||
from rwkv_tokenizer import RWKV_TOKENIZER, TRIE_TOKENIZER
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
def build_tokenizer(args):
|
||||
"""Initialize tokenizer."""
|
||||
if args.rank == 0:
|
||||
print("> building {} tokenizer ...".format(args.tokenizer_type), flush=True)
|
||||
|
||||
# Select and instantiate the tokenizer.
|
||||
|
||||
if args.tokenizer_type.lower() == "HFTokenizer".lower():
|
||||
assert args.vocab_file is not None
|
||||
tokenizer = HFTokenizer(args.vocab_file)
|
||||
elif args.tokenizer_type.lower() == "RWKVTokenizer".lower():
|
||||
assert args.vocab_file is not None
|
||||
tokenizer = RWKVTokenizer(args.vocab_file)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"{} tokenizer is not " "implemented.".format(args.tokenizer_type)
|
||||
)
|
||||
|
||||
# Add vocab size.
|
||||
args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _vocab_size_with_padding(orig_vocab_size, args):
|
||||
"""Pad vocab size so it is divisible by model parallel size and
|
||||
still having GPU friendly size."""
|
||||
|
||||
after = orig_vocab_size
|
||||
multiple = args.make_vocab_size_divisible_by * args.model_parallel_size
|
||||
while (after % multiple) != 0:
|
||||
after += 1
|
||||
if args.rank == 0:
|
||||
print(
|
||||
" > padded vocab (size: {}) with {} dummy tokens "
|
||||
"(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after),
|
||||
flush=True,
|
||||
)
|
||||
return after
|
||||
|
||||
|
||||
class AbstractTokenizer(ABC):
|
||||
"""Abstract class for tokenizer."""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def vocab_size(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def vocab(self):
|
||||
"""Dictionary from vocab text token to id token."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def inv_vocab(self):
|
||||
"""Dictionary from vocab id token to text token."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def tokenize(self, text):
|
||||
pass
|
||||
|
||||
def detokenize(self, token_ids):
|
||||
raise NotImplementedError(
|
||||
"detokenizer is not implemented for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
raise NotImplementedError(
|
||||
"CLS is not provided for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
@property
|
||||
def sep(self):
|
||||
raise NotImplementedError(
|
||||
"SEP is not provided for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
@property
|
||||
def pad(self):
|
||||
raise NotImplementedError(
|
||||
"PAD is not provided for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
@property
|
||||
def eod(self):
|
||||
raise NotImplementedError(
|
||||
"EOD is not provided for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
@property
|
||||
def mask(self):
|
||||
raise NotImplementedError(
|
||||
"MASK is not provided for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
|
||||
class HFTokenizer(AbstractTokenizer):
|
||||
"""Designed to Integrate HF's Tokenizer library."""
|
||||
|
||||
def __init__(self, vocab_file):
|
||||
name = "HFTokenizer"
|
||||
super().__init__(name)
|
||||
|
||||
self.tokenizer = Tokenizer.from_file(vocab_file)
|
||||
self.eod_id = self.tokenizer.token_to_id("<|endoftext|>")
|
||||
self.pad_id = self.tokenizer.token_to_id("<|padding|>")
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.get_vocab_size()
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self.tokenizer.get_vocab()
|
||||
|
||||
@property
|
||||
def inv_vocab(self):
|
||||
return self.tokenizer.decoder
|
||||
|
||||
def tokenize(self, text: str):
|
||||
return self.tokenizer.encode(text).ids
|
||||
|
||||
def tokenize_batch(self, text_batch: Union[List[str], str]):
|
||||
return self.tokenizer.encode_batch(text_batch)
|
||||
|
||||
def detokenize(self, token_ids):
|
||||
return self.tokenizer.decode(token_ids)
|
||||
|
||||
@property
|
||||
def eod(self):
|
||||
return self.eod_id
|
||||
|
||||
|
||||
class RWKVTokenizer(AbstractTokenizer):
|
||||
"""RWKV Worlds Tokenizer."""
|
||||
|
||||
def __init__(self, vocab_file='rwkv_vocab_v20230424.txt'):
|
||||
name = "RWKVTokenizer"
|
||||
super().__init__(name)
|
||||
|
||||
self.tokenizer = TRIE_TOKENIZER(vocab_file)
|
||||
self.eod_id = 0 # self.tokenizer.token_to_id("<|endoftext|>")
|
||||
# self.pad_id = self.tokenizer.token_to_id("<|padding|>")
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.get_vocab_size()
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self.tokenizer.get_vocab()
|
||||
|
||||
@property
|
||||
def inv_vocab(self):
|
||||
return self.tokenizer.decode
|
||||
|
||||
def tokenize(self, text: str):
|
||||
return self.tokenizer.encode(text)
|
||||
|
||||
def tokenize_batch(self, text_batch: Union[List[str], str]):
|
||||
return self.tokenizer.encode_batch(text_batch)
|
||||
|
||||
def detokenize(self, token_ids):
|
||||
return self.tokenizer.decode(token_ids)
|
||||
|
||||
@property
|
||||
def eod(self):
|
||||
return self.eod_id
|
||||
133
finetune/lora/cuda/wkv_cuda.cu
vendored
Normal file
133
finetune/lora/cuda/wkv_cuda.cu
vendored
Normal file
@@ -0,0 +1,133 @@
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
|
||||
#define MIN_VALUE (-1e38)
|
||||
|
||||
template <typename F>
|
||||
__global__ void kernel_forward(const int B, const int T, const int C,
|
||||
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
|
||||
F *__restrict__ const _y) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
|
||||
F u = _u[_c];
|
||||
F w = _w[_c];
|
||||
const F *__restrict__ const k = _k + _offset;
|
||||
const F *__restrict__ const v = _v + _offset;
|
||||
F *__restrict__ const y = _y + _offset;
|
||||
|
||||
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
|
||||
F aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const F kk = k[ii];
|
||||
const F vv = v[ii];
|
||||
|
||||
F ww = u + kk;
|
||||
F p = max(pp, ww);
|
||||
F e1 = exp(pp - p);
|
||||
F e2 = exp(ww - p);
|
||||
y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
|
||||
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
__global__ void kernel_backward(const int B, const int T, const int C,
|
||||
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
|
||||
const F *__restrict__ const _y, const F *__restrict__ const _gy,
|
||||
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
|
||||
F u = _u[_c];
|
||||
F w = _w[_c];
|
||||
const F *__restrict__ const k = _k + _offset;
|
||||
const F *__restrict__ const v = _v + _offset;
|
||||
const F *__restrict__ const y = _y + _offset;
|
||||
const F *__restrict__ const gy = _gy + _offset;
|
||||
F *__restrict__ const gk = _gk + _offset;
|
||||
F *__restrict__ const gv = _gv + _offset;
|
||||
|
||||
F q[Tmax], r[Tmax];
|
||||
|
||||
F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const F kk = k[ii];
|
||||
const F vv = v[ii];
|
||||
const F yy = y[ii];
|
||||
|
||||
F ww = u + kk;
|
||||
F p = max(pp, ww);
|
||||
F e1 = exp(pp - p);
|
||||
F e2 = exp(ww - p);
|
||||
const F qq = gy[ii] / (e1 * bb + e2);
|
||||
gw += (ga - gb * yy) * e1 * qq;
|
||||
gu += (vv - yy) * e2 * qq;
|
||||
q[i] = qq;
|
||||
r[i] = ww - p;
|
||||
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
ga = e1 * (aa + ga);
|
||||
gb = e1 * (bb + gb);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
const int _offsetBC = _b * C + _c;
|
||||
_gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
|
||||
_gu[_offsetBC] = gu;
|
||||
|
||||
aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
for (int i = T - 1; i >= 0; i--) {
|
||||
const int ii = i * C;
|
||||
const F kk = k[ii];
|
||||
const F vv = v[ii];
|
||||
const F yy = y[ii];
|
||||
const F qq = q[i];
|
||||
const F rr = r[i];
|
||||
|
||||
F e1 = qq * exp(rr);
|
||||
F e2 = exp(kk + pp);
|
||||
gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
|
||||
gv[ii] = e1 + e2 * aa;
|
||||
|
||||
const F ww = w + pp;
|
||||
const F www = rr - u - kk;
|
||||
const F p = max(ww, www);
|
||||
e1 = exp(ww - p);
|
||||
e2 = qq * exp(www - p);
|
||||
aa = e1 * aa + e2;
|
||||
bb = e1 * bb - e2 * yy;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
|
||||
}
|
||||
|
||||
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
|
||||
}
|
||||
132
finetune/lora/cuda/wkv_cuda_bf16.cu
vendored
Normal file
132
finetune/lora/cuda/wkv_cuda_bf16.cu
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
#include "ATen/ATen.h"
|
||||
#define MIN_VALUE (-1e38)
|
||||
typedef at::BFloat16 bf16;
|
||||
|
||||
__global__ void kernel_forward(const int B, const int T, const int C,
|
||||
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
|
||||
bf16 *__restrict__ const _y) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
|
||||
float u = float(_u[_c]);
|
||||
float w = _w[_c];
|
||||
const bf16 *__restrict__ const k = _k + _offset;
|
||||
const bf16 *__restrict__ const v = _v + _offset;
|
||||
bf16 *__restrict__ const y = _y + _offset;
|
||||
|
||||
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
|
||||
float aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const float kk = float(k[ii]);
|
||||
const float vv = float(v[ii]);
|
||||
|
||||
float ww = u + kk;
|
||||
float p = max(pp, ww);
|
||||
float e1 = exp(pp - p);
|
||||
float e2 = exp(ww - p);
|
||||
y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
|
||||
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void kernel_backward(const int B, const int T, const int C,
|
||||
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
|
||||
const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
|
||||
bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
|
||||
float u = float(_u[_c]);
|
||||
float w = _w[_c];
|
||||
const bf16 *__restrict__ const k = _k + _offset;
|
||||
const bf16 *__restrict__ const v = _v + _offset;
|
||||
const bf16 *__restrict__ const y = _y + _offset;
|
||||
const bf16 *__restrict__ const gy = _gy + _offset;
|
||||
bf16 *__restrict__ const gk = _gk + _offset;
|
||||
bf16 *__restrict__ const gv = _gv + _offset;
|
||||
|
||||
float q[Tmax], r[Tmax];
|
||||
|
||||
float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const float kk = float(k[ii]);
|
||||
const float vv = float(v[ii]);
|
||||
const float yy = float(y[ii]);
|
||||
|
||||
float ww = u + kk;
|
||||
float p = max(pp, ww);
|
||||
float e1 = exp(pp - p);
|
||||
float e2 = exp(ww - p);
|
||||
const float qq = float(gy[ii]) / (e1 * bb + e2);
|
||||
gw += (ga - gb * yy) * e1 * qq;
|
||||
gu += (vv - yy) * e2 * qq;
|
||||
q[i] = qq;
|
||||
r[i] = ww - p;
|
||||
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
ga = e1 * (aa + ga);
|
||||
gb = e1 * (bb + gb);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
const int _offsetBC = _b * C + _c;
|
||||
_gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
|
||||
_gu[_offsetBC] = bf16(gu);
|
||||
|
||||
aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
for (int i = T - 1; i >= 0; i--) {
|
||||
const int ii = i * C;
|
||||
const float kk = float(k[ii]);
|
||||
const float vv = float(v[ii]);
|
||||
const float yy = float(y[ii]);
|
||||
const float qq = q[i];
|
||||
const float rr = r[i];
|
||||
|
||||
float e1 = qq * exp(rr);
|
||||
float e2 = exp(kk + pp);
|
||||
gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
|
||||
gv[ii] = bf16(e1 + e2 * aa);
|
||||
|
||||
const float ww = w + pp;
|
||||
const float www = rr - u - kk;
|
||||
const float p = max(ww, www);
|
||||
e1 = exp(ww - p);
|
||||
e2 = qq * exp(www - p);
|
||||
aa = e1 * aa + e2;
|
||||
bb = e1 * bb - e2 * yy;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
|
||||
}
|
||||
|
||||
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
|
||||
}
|
||||
21
finetune/lora/cuda/wkv_op.cpp
vendored
Normal file
21
finetune/lora/cuda/wkv_op.cpp
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
|
||||
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
|
||||
|
||||
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
|
||||
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
|
||||
}
|
||||
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
|
||||
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &forward, "wkv forward");
|
||||
m.def("backward", &backward, "wkv backward");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(wkv, m) {
|
||||
m.def("forward", forward);
|
||||
m.def("backward", backward);
|
||||
}
|
||||
25
finetune/lora/cuda/wkv_op_bf16.cpp
vendored
Normal file
25
finetune/lora/cuda/wkv_op_bf16.cpp
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
#include <torch/extension.h>
|
||||
#include "ATen/ATen.h"
|
||||
typedef at::BFloat16 bf16;
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
|
||||
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
|
||||
|
||||
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
|
||||
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
|
||||
}
|
||||
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
|
||||
torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
|
||||
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
|
||||
gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &forward, "wkv forward");
|
||||
m.def("backward", &backward, "wkv backward");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(wkv, m) {
|
||||
m.def("forward", forward);
|
||||
m.def("backward", backward);
|
||||
}
|
||||
53
finetune/lora/merge_lora.py
vendored
Normal file
53
finetune/lora/merge_lora.py
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict
|
||||
import typing
|
||||
import torch
|
||||
|
||||
if '-h' in sys.argv or '--help' in sys.argv:
|
||||
print(f'Usage: python3 {sys.argv[0]} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>')
|
||||
|
||||
if sys.argv[1] == '--use-gpu':
|
||||
device = 'cuda'
|
||||
lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5]
|
||||
else:
|
||||
device = 'cpu'
|
||||
lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4]
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
|
||||
# merge LoRA-only slim checkpoint into the main weights
|
||||
w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
|
||||
for k in w_lora.keys():
|
||||
w[k] = w_lora[k]
|
||||
output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
|
||||
# merge LoRA weights
|
||||
keys = list(w.keys())
|
||||
for k in keys:
|
||||
if k.endswith('.weight'):
|
||||
prefix = k[:-len('.weight')]
|
||||
lora_A = prefix + '.lora_A'
|
||||
lora_B = prefix + '.lora_B'
|
||||
if lora_A in keys:
|
||||
assert lora_B in keys
|
||||
print(f'merging {lora_A} and {lora_B} into {k}')
|
||||
assert w[lora_B].shape[1] == w[lora_A].shape[0]
|
||||
lora_r = w[lora_B].shape[1]
|
||||
w[k] = w[k].to(device=device)
|
||||
w[lora_A] = w[lora_A].to(device=device)
|
||||
w[lora_B] = w[lora_B].to(device=device)
|
||||
w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
|
||||
output_w[k] = w[k].to(device='cpu', copy=True)
|
||||
del w[k]
|
||||
del w[lora_A]
|
||||
del w[lora_B]
|
||||
continue
|
||||
|
||||
if 'lora' not in k:
|
||||
print(f'retaining {k}')
|
||||
output_w[k] = w[k].clone()
|
||||
del w[k]
|
||||
|
||||
torch.save(output_w, output)
|
||||
0
finetune/lora/src/__init__.py
vendored
Normal file
0
finetune/lora/src/__init__.py
vendored
Normal file
269
finetune/lora/src/binidx.py
vendored
Normal file
269
finetune/lora/src/binidx.py
vendored
Normal file
@@ -0,0 +1,269 @@
|
||||
from lib2to3.pgen2 import token
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import shutil
|
||||
import struct
|
||||
from functools import lru_cache
|
||||
from itertools import accumulate
|
||||
|
||||
def print_rank_0(*message):
|
||||
pass
|
||||
# """If distributed is initialized print only on rank 0."""
|
||||
# if torch.distributed.is_initialized():
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(*message, flush=True)
|
||||
# else:
|
||||
# print(*message, flush=True)
|
||||
|
||||
def _warmup_mmap_file(path):
|
||||
pass
|
||||
# with open(path, "rb") as stream:
|
||||
# while stream.read(100 * 1024 * 1024):
|
||||
# pass
|
||||
|
||||
dtypes = {
|
||||
1: np.uint8,
|
||||
2: np.int8,
|
||||
3: np.int16,
|
||||
4: np.int32,
|
||||
5: np.int64,
|
||||
6: float,
|
||||
7: np.double,
|
||||
8: np.uint16,
|
||||
}
|
||||
|
||||
def code(dtype):
|
||||
for k in dtypes.keys():
|
||||
if dtypes[k] == dtype:
|
||||
return k
|
||||
raise ValueError(dtype)
|
||||
|
||||
def index_file_path(prefix_path):
|
||||
return prefix_path + ".idx"
|
||||
|
||||
def data_file_path(prefix_path):
|
||||
return prefix_path + ".bin"
|
||||
|
||||
class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||
class Index(object):
|
||||
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||
|
||||
@classmethod
|
||||
def writer(cls, path, dtype):
|
||||
class _Writer(object):
|
||||
def __enter__(self):
|
||||
self._file = open(path, "wb")
|
||||
|
||||
# Write Magic string so we can check the file format then opening it again.
|
||||
self._file.write(cls._HDR_MAGIC)
|
||||
# Write version number
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", 1))
|
||||
# Little endian unsigned 8 Bit integer
|
||||
self._file.write(struct.pack("<B", code(dtype)))
|
||||
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _get_pointers(sizes):
|
||||
dtype_size = dtype().itemsize
|
||||
address = 0
|
||||
pointers = []
|
||||
|
||||
for size in sizes:
|
||||
pointers.append(address)
|
||||
address += size * dtype_size
|
||||
|
||||
return pointers
|
||||
|
||||
def write(self, sizes, doc_idx):
|
||||
pointers = self._get_pointers(sizes)
|
||||
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", len(sizes)))
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", len(doc_idx)))
|
||||
|
||||
sizes = np.array(sizes, dtype=np.int32)
|
||||
self._file.write(sizes.tobytes(order="C"))
|
||||
del sizes
|
||||
|
||||
pointers = np.array(pointers, dtype=np.int64)
|
||||
self._file.write(pointers.tobytes(order="C"))
|
||||
del pointers
|
||||
|
||||
doc_idx = np.array(doc_idx, dtype=np.int64)
|
||||
self._file.write(doc_idx.tobytes(order="C"))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._file.close()
|
||||
|
||||
return _Writer()
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
with open(path, "rb") as stream:
|
||||
magic_test = stream.read(9)
|
||||
assert self._HDR_MAGIC == magic_test, (
|
||||
"Index file doesn't match expected format. "
|
||||
"Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
# Little endian unsigned 64 Bit integer
|
||||
version = struct.unpack("<Q", stream.read(8))
|
||||
assert (1,) == version
|
||||
|
||||
# Little endian unsigned 8 Bit integer
|
||||
(dtype_code,) = struct.unpack("<B", stream.read(1))
|
||||
self._dtype = dtypes[dtype_code]
|
||||
self._dtype_size = self._dtype().itemsize
|
||||
|
||||
self._len = struct.unpack("<Q", stream.read(8))[0]
|
||||
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
|
||||
offset = stream.tell()
|
||||
|
||||
if not skip_warmup:
|
||||
print_rank_0(" warming up index mmap file...")
|
||||
_warmup_mmap_file(path)
|
||||
|
||||
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
print_rank_0(" reading sizes...")
|
||||
self._sizes = np.frombuffer(
|
||||
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
|
||||
)
|
||||
print_rank_0(" reading pointers...")
|
||||
self._pointers = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._len,
|
||||
offset=offset + self._sizes.nbytes,
|
||||
)
|
||||
print_rank_0(" reading document index...")
|
||||
self._doc_idx = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._doc_count,
|
||||
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._doc_idx
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def __getitem__(self, i):
|
||||
return self._pointers[i], self._sizes[i]
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
super().__init__()
|
||||
|
||||
self._path = None
|
||||
self._index = None
|
||||
self._bin_buffer = None
|
||||
|
||||
self._do_init(path, skip_warmup)
|
||||
|
||||
def __getstate__(self):
|
||||
return self._path
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._do_init(state)
|
||||
|
||||
def _do_init(self, path, skip_warmup):
|
||||
self._path = path
|
||||
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
||||
|
||||
if not skip_warmup:
|
||||
print_rank_0(" warming up data mmap file...")
|
||||
_warmup_mmap_file(data_file_path(self._path))
|
||||
print_rank_0(" creating numpy buffer of mmap...")
|
||||
self._bin_buffer_mmap = np.memmap(
|
||||
data_file_path(self._path), mode="r", order="C"
|
||||
)
|
||||
print_rank_0(" creating memory view of numpy buffer...")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
del self._index
|
||||
|
||||
def __len__(self):
|
||||
return len(self._index)
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
ptr, size = self._index[idx]
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||
)
|
||||
return np_array
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError(
|
||||
"Slices into indexed_dataset must be contiguous")
|
||||
ptr = self._index._pointers[start]
|
||||
sizes = self._index._sizes[idx]
|
||||
offsets = list(accumulate(sizes))
|
||||
total_size = sum(sizes)
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
|
||||
)
|
||||
sents = np.split(np_array, offsets[:-1])
|
||||
return sents
|
||||
|
||||
def get(self, idx, offset=0, length=None):
|
||||
"""Retrieves a single item from the dataset with the option to only
|
||||
return a portion of the item.
|
||||
|
||||
get(idx) is the same as [idx] but get() does not support slicing.
|
||||
"""
|
||||
ptr, size = self._index[idx]
|
||||
if length is None:
|
||||
length = size - offset
|
||||
ptr += offset * np.dtype(self._index.dtype).itemsize
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
||||
)
|
||||
return np_array
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._index.sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._index.doc_idx
|
||||
|
||||
def get_doc_idx(self):
|
||||
return self._index._doc_idx
|
||||
|
||||
def set_doc_idx(self, doc_idx_):
|
||||
self._index._doc_idx = doc_idx_
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return os.path.exists(index_file_path(path)) and os.path.exists(
|
||||
data_file_path(path)
|
||||
)
|
||||
224
finetune/lora/src/dataset.py
vendored
Normal file
224
finetune/lora/src/dataset.py
vendored
Normal file
@@ -0,0 +1,224 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
|
||||
import json, math, random, os, sys
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
from .binidx import MMapIndexedDataset
|
||||
from .utils import MaybeIsPrime
|
||||
|
||||
|
||||
class MyDataset(Dataset):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
if args.data_type == "binidx":
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
|
||||
|
||||
if args.data_file.endswith('/'):
|
||||
d_all = []
|
||||
for p in os.listdir(args.data_file):
|
||||
if p.endswith(".idx"):
|
||||
d_all += [p[:-4]]
|
||||
d_all.sort()
|
||||
rank_zero_info(d_all)
|
||||
exit(0)
|
||||
else:
|
||||
self.data = MMapIndexedDataset(args.data_file)
|
||||
self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
|
||||
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||
|
||||
if args.my_qa_mask > 0:
|
||||
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
|
||||
self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size
|
||||
|
||||
if args.my_pile_stage > 0:
|
||||
# assert self.data_size == 332115325534 and self.vocab_size == 50277
|
||||
self.samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||
assert self.samples_per_epoch == 40320
|
||||
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
|
||||
dataset_slot = self.data_size // args.ctx_len
|
||||
if args.my_pile_stage != 4:
|
||||
assert MaybeIsPrime(args.magic_prime)
|
||||
assert args.magic_prime % 3 == 2
|
||||
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
|
||||
elif args.data_type == "numpy":
|
||||
self.data = np.load(args.data_file).astype("int")
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
||||
self.data_size = len(self.data)
|
||||
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||
elif args.data_type == "uint16":
|
||||
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
||||
self.data_size = self.data.shape[0]
|
||||
rank_zero_info(f"Data has {self.data_size} samples.")
|
||||
elif args.data_type == "wds_img":
|
||||
self.vocab_size = -1
|
||||
self.data_size = -1
|
||||
self.data = None
|
||||
self.error_count = 0
|
||||
else:
|
||||
if args.data_type == "dummy":
|
||||
rank_zero_info("Building dummy data...")
|
||||
self.data = ""
|
||||
for i in range(100000):
|
||||
aa = (i) % 10000
|
||||
bb = (i * i) % 10000
|
||||
cc = aa + bb
|
||||
self.data += f".{aa}+{bb}={cc}."
|
||||
else:
|
||||
self.data = open(args.data_file, "r", encoding=args.data_type).read()
|
||||
rank_zero_info("Building token list...")
|
||||
unique = sorted(list(set(self.data)))
|
||||
self.vocab_size = len(unique)
|
||||
# rank_zero_info()
|
||||
# for u in unique:
|
||||
# print(u, end=' ')
|
||||
# rank_zero_info('\n\n')
|
||||
xx = 0
|
||||
xxObj = {}
|
||||
for u in unique:
|
||||
xxObj[xx] = u
|
||||
xx += 1
|
||||
with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
|
||||
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
|
||||
self.data_size = len(self.data)
|
||||
rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
|
||||
self.stoi = {ch: i for i, ch in enumerate(unique)}
|
||||
self.itos = {i: ch for i, ch in enumerate(unique)}
|
||||
|
||||
def __len__(self):
|
||||
return self.args.epoch_steps * self.args.micro_bsz
|
||||
|
||||
def __getitem__(self, idx):
|
||||
args = self.args
|
||||
rank = self.global_rank
|
||||
epoch = self.real_epoch
|
||||
world_size = self.world_size
|
||||
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
|
||||
|
||||
if args.data_type == "wds_img":
|
||||
def init_wds(self, bias=0):
|
||||
def identity(x):
|
||||
return x
|
||||
import webdataset as wds
|
||||
import torchvision.transforms as transforms
|
||||
# img_transform = transforms.Compose(
|
||||
# [transforms.CenterCrop(256)]
|
||||
# )
|
||||
img_transform = transforms.Compose([
|
||||
transforms.CenterCrop(512),
|
||||
transforms.Resize((args.my_img_size))
|
||||
])
|
||||
self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
|
||||
for pp in self.data_raw.pipeline:
|
||||
if 'Resampled' in str(pp):
|
||||
pp.deterministic = True
|
||||
def worker_seed():
|
||||
return rank*100000+epoch+bias*1e9
|
||||
pp.worker_seed = worker_seed
|
||||
self.data = iter(self.data_raw)
|
||||
# print(f"WebDataset loaded for rank {rank} epoch {epoch}")
|
||||
if self.data == None:
|
||||
init_wds(self)
|
||||
trial = 0
|
||||
while trial < 10:
|
||||
try:
|
||||
dd = next(self.data) # jpg, json, txt
|
||||
break
|
||||
except:
|
||||
print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
|
||||
self.error_count += 1
|
||||
init_wds(self, self.error_count)
|
||||
trial += 1
|
||||
pass
|
||||
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
|
||||
# with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
|
||||
# tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
|
||||
return dd[0], dd[2]
|
||||
else:
|
||||
if args.data_type == "uint16":
|
||||
i = np.random.randint(0, self.data_size-1)
|
||||
dix = self.data[i]
|
||||
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||
y = torch.tensor(dix[1:], dtype=torch.long)
|
||||
else:
|
||||
ctx_len = args.ctx_len
|
||||
req_len = ctx_len + 1
|
||||
magic_prime = args.magic_prime
|
||||
data = self.data
|
||||
|
||||
if args.my_pile_stage > 0 and args.my_pile_stage != 4:
|
||||
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
|
||||
|
||||
if args.my_qa_mask > 0:
|
||||
ii_orig = ii
|
||||
if ii % 2 == 0:
|
||||
ii = (ii // 2) * args.magic_prime
|
||||
if args.ctx_len == 1024:
|
||||
magic_prime = 324331313
|
||||
elif args.ctx_len == 2048:
|
||||
magic_prime = 162165671
|
||||
elif args.ctx_len == 4096:
|
||||
magic_prime = 81082817
|
||||
data = self.data_pile
|
||||
else:
|
||||
ii = ii // 2
|
||||
|
||||
factor = (math.sqrt(5) - 1) / 2
|
||||
factor = int(magic_prime * factor)
|
||||
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
|
||||
if (args.my_qa_mask == 0) or (data == self.data_pile):
|
||||
i = i + args.my_pile_shift
|
||||
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
|
||||
else:
|
||||
# cheat: pick a random spot in dataset
|
||||
i = np.random.randint(0, self.data_size - req_len)
|
||||
|
||||
if args.data_type == "binidx":
|
||||
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
|
||||
elif args.data_type == "numpy":
|
||||
dix = data[i : i + req_len]
|
||||
else:
|
||||
dix = [self.stoi[s] for s in data[i : i + req_len]]
|
||||
|
||||
if args.my_qa_mask == 1:
|
||||
if data == self.data_pile:
|
||||
z = [1] * ctx_len
|
||||
else:
|
||||
z = [0] * ctx_len
|
||||
z_sum = 0
|
||||
isGood = False
|
||||
for i in range(3, ctx_len):
|
||||
if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
|
||||
isGood = True
|
||||
if dix[i] == 0:
|
||||
isGood = False
|
||||
if isGood:
|
||||
z[i] = 1
|
||||
z_sum += 1
|
||||
if z_sum == 0:
|
||||
z = [1] * ctx_len
|
||||
i = np.random.randint(0, self.data_pile_size - req_len)
|
||||
dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
|
||||
z = torch.tensor(z, dtype=torch.bfloat16)
|
||||
|
||||
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||
y = torch.tensor(dix[1:], dtype=torch.long)
|
||||
|
||||
# if ii_orig < 50:
|
||||
# # if rank == 1:
|
||||
# print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
|
||||
# else:
|
||||
# exit(0)
|
||||
|
||||
if args.my_qa_mask == 1:
|
||||
return x, y, z
|
||||
|
||||
return x, y
|
||||
678
finetune/lora/src/model.py
vendored
Normal file
678
finetune/lora/src/model.py
vendored
Normal file
@@ -0,0 +1,678 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
|
||||
import functools
|
||||
import os, math, gc, importlib
|
||||
import torch
|
||||
# torch._C._jit_set_profiling_executor(True)
|
||||
# torch._C._jit_set_profiling_mode(True)
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||
from torch.nn import functional as F
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
from pytorch_lightning.strategies import DeepSpeedStrategy
|
||||
if importlib.util.find_spec('deepspeed'):
|
||||
import deepspeed
|
||||
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
||||
|
||||
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
|
||||
|
||||
LORA_CONFIG = {
|
||||
"r": 0,
|
||||
"alpha": 0,
|
||||
"dropout": 0,
|
||||
"parts": {"att", "ln", "time"},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
|
||||
except:
|
||||
os.environ["RWKV_MY_TESTING"] = ''
|
||||
|
||||
def __nop(ob):
|
||||
return ob
|
||||
|
||||
|
||||
MyModule = nn.Module
|
||||
MyFunction = __nop
|
||||
if os.environ["RWKV_JIT_ON"] == "1":
|
||||
MyModule = torch.jit.ScriptModule
|
||||
MyFunction = torch.jit.script_method
|
||||
|
||||
|
||||
########################################################################################################
|
||||
# CUDA Kernel
|
||||
########################################################################################################
|
||||
|
||||
T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
|
||||
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["finetune/lora/cuda/wkv_op_bf16.cpp", "finetune/lora/cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
|
||||
class WKV(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, B, T, C, w, u, k, v):
|
||||
ctx.B = B
|
||||
ctx.T = T
|
||||
ctx.C = C
|
||||
assert T <= T_MAX
|
||||
assert B * C % min(C, 32) == 0
|
||||
w = -torch.exp(w.float().contiguous())
|
||||
u = u.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
||||
ctx.save_for_backward(w, u, k, v, y)
|
||||
return y
|
||||
@staticmethod
|
||||
def backward(ctx, gy):
|
||||
B = ctx.B
|
||||
T = ctx.T
|
||||
C = ctx.C
|
||||
assert T <= T_MAX
|
||||
assert B * C % min(C, 32) == 0
|
||||
w, u, k, v, y = ctx.saved_tensors
|
||||
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
|
||||
gw = torch.sum(gw, dim=0)
|
||||
gu = torch.sum(gu, dim=0)
|
||||
return (None, None, None, gw, gu, gk, gv)
|
||||
else:
|
||||
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["finetune/lora/cuda/wkv_op.cpp", "finetune/lora/cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
|
||||
class WKV(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, B, T, C, w, u, k, v):
|
||||
ctx.B = B
|
||||
ctx.T = T
|
||||
ctx.C = C
|
||||
assert T <= T_MAX
|
||||
assert B * C % min(C, 32) == 0
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
w = -torch.exp(w.contiguous())
|
||||
u = u.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
else:
|
||||
w = -torch.exp(w.float().contiguous())
|
||||
u = u.float().contiguous()
|
||||
k = k.float().contiguous()
|
||||
v = v.float().contiguous()
|
||||
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
|
||||
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
||||
ctx.save_for_backward(w, u, k, v, y)
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
return y
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||
return y.half()
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
return y.bfloat16()
|
||||
@staticmethod
|
||||
def backward(ctx, gy):
|
||||
B = ctx.B
|
||||
T = ctx.T
|
||||
C = ctx.C
|
||||
assert T <= T_MAX
|
||||
assert B * C % min(C, 32) == 0
|
||||
w, u, k, v, y = ctx.saved_tensors
|
||||
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
|
||||
else:
|
||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
|
||||
gw = torch.sum(gw, dim=0)
|
||||
gu = torch.sum(gu, dim=0)
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
return (None, None, None, gw, gu, gk, gv)
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
|
||||
|
||||
|
||||
def RUN_CUDA(B, T, C, w, u, k, v):
|
||||
return WKV.apply(B, T, C, w, u, k, v)
|
||||
|
||||
|
||||
########################################################################################################
|
||||
# LoRA
|
||||
########################################################################################################
|
||||
|
||||
|
||||
class LoraLinear(nn.Module):
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
|
||||
assert bias == False, "Biased LoraLinear not supported"
|
||||
|
||||
r, alpha, dropout = LORA_CONFIG["r"], LORA_CONFIG[
|
||||
"alpha"], LORA_CONFIG["dropout"]
|
||||
self.lora_A = nn.Parameter(torch.empty(r, in_features))
|
||||
self.lora_B = nn.Parameter(torch.empty(out_features, r))
|
||||
self.lora_dropout = nn.Dropout(dropout)
|
||||
self.scaling = alpha / r
|
||||
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def forward(self, x):
|
||||
return (
|
||||
F.linear(x, self.weight) + self.scaling *
|
||||
F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B))
|
||||
|
||||
|
||||
@functools.wraps(LoraLinear)
|
||||
def make_linear_att(*args, **kwargs):
|
||||
if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
|
||||
return LoraLinear(*args, **kwargs)
|
||||
else:
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
@functools.wraps(LoraLinear)
|
||||
def make_linear_ffn(*args, **kwargs):
|
||||
if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
|
||||
return LoraLinear(*args, **kwargs)
|
||||
else:
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
########################################################################################################
|
||||
# RWKV: RWKV Time-mix + RWKV Channel-mix
|
||||
########################################################################################################
|
||||
|
||||
|
||||
class RWKV_TimeMix(MyModule):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
self.ctx_len = args.ctx_len
|
||||
self.n_embd = args.n_embd
|
||||
|
||||
with torch.no_grad(): # fancy init
|
||||
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
|
||||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||||
ddd = torch.ones(1, 1, args.n_embd)
|
||||
for i in range(args.n_embd):
|
||||
ddd[0, 0, i] = i / args.n_embd
|
||||
|
||||
# fancy time_decay
|
||||
decay_speed = torch.ones(args.dim_att)
|
||||
for h in range(args.dim_att):
|
||||
decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
|
||||
self.time_decay = nn.Parameter(decay_speed)
|
||||
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
|
||||
|
||||
# fancy time_first
|
||||
zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
|
||||
self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)
|
||||
|
||||
# fancy time_mix
|
||||
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
||||
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
||||
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
|
||||
self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||
self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||
self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||
|
||||
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
|
||||
|
||||
if 'a' in os.environ["RWKV_MY_TESTING"]:
|
||||
self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
||||
d_qkv = args.n_embd // 16
|
||||
self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
|
||||
self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
|
||||
self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
|
||||
self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
|
||||
with torch.no_grad():
|
||||
self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
||||
|
||||
if 'a' not in os.environ["RWKV_MY_TESTING"]:
|
||||
@MyFunction
|
||||
def jit_func(self, x):
|
||||
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
k = self.key(xk)
|
||||
v = self.value(xv)
|
||||
r = self.receptance(xr)
|
||||
sr = torch.sigmoid(r)
|
||||
return sr, k, v
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.size() # x = (Batch,Time,Channel)
|
||||
sr, k, v = self.jit_func(x)
|
||||
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
|
||||
return self.output(rwkv)
|
||||
|
||||
if 'a' in os.environ["RWKV_MY_TESTING"]:
|
||||
@MyFunction
|
||||
def QKV(self, q, k, v):
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
att = att.masked_fill(self.att_mask == 0, float('-inf'))
|
||||
att = F.softmax(att, dim = -1)
|
||||
x = att @ v
|
||||
return x
|
||||
|
||||
@MyFunction
|
||||
def jit_funcQKV(self, x):
|
||||
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
|
||||
xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
|
||||
xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
|
||||
k = self.key(xk)
|
||||
v = self.value(xv)
|
||||
r = self.receptance(xr)
|
||||
sr = torch.sigmoid(r)
|
||||
qq = self.qq(xqq)
|
||||
kk = self.kk(xkk)
|
||||
vv = self.vv(xvv)
|
||||
return sr, k, v, qq, kk, vv
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.size() # x = (Batch,Time,Channel)
|
||||
sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
|
||||
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
|
||||
rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
|
||||
return rwkv
|
||||
|
||||
########################################################################################################
|
||||
|
||||
class RWKV_ChannelMix(MyModule):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
|
||||
with torch.no_grad(): # fancy init of time_mix
|
||||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||||
ddd = torch.ones(1, 1, args.n_embd)
|
||||
for i in range(args.n_embd):
|
||||
ddd[0, 0, i] = i / args.n_embd
|
||||
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
|
||||
self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
|
||||
self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
|
||||
self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)
|
||||
|
||||
@MyFunction
|
||||
def forward(self, x):
|
||||
xx = self.time_shift(x)
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
k = self.key(xk)
|
||||
k = torch.square(torch.relu(k))
|
||||
kv = self.value(k)
|
||||
return torch.sigmoid(self.receptance(xr)) * kv
|
||||
|
||||
class MishGLU(MyModule):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
|
||||
with torch.no_grad():
|
||||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
|
||||
|
||||
x = torch.ones(1, 1, args.n_embd)
|
||||
for i in range(args.n_embd):
|
||||
x[0, 0, i] = i / args.n_embd
|
||||
|
||||
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
||||
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
||||
self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
||||
self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
||||
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
|
||||
|
||||
@MyFunction
|
||||
def forward(self, x):
|
||||
xx = self.time_shift(x)
|
||||
xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
a = self.aa(xa)
|
||||
b = self.bb(xb)
|
||||
return self.value(a * F.mish(b))
|
||||
|
||||
########################################################################################################
|
||||
# The RWKV Model with our blocks
|
||||
########################################################################################################
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.ln1 = nn.LayerNorm(args.n_embd)
|
||||
self.ln2 = nn.LayerNorm(args.n_embd)
|
||||
|
||||
if self.layer_id == 0:
|
||||
self.ln0 = nn.LayerNorm(args.n_embd)
|
||||
if args.my_pos_emb > 0:
|
||||
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
|
||||
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
|
||||
|
||||
if self.layer_id == 0 and self.args.pre_ffn > 0:
|
||||
self.ffnPre = RWKV_ChannelMix(args, 0)
|
||||
else:
|
||||
self.att = RWKV_TimeMix(args, layer_id)
|
||||
|
||||
if 'g' in os.environ["RWKV_MY_TESTING"]:
|
||||
self.ffn = MishGLU(args, layer_id)
|
||||
else:
|
||||
self.ffn = RWKV_ChannelMix(args, layer_id)
|
||||
|
||||
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
||||
self.tiny_ln = nn.LayerNorm(args.n_embd)
|
||||
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
|
||||
self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
||||
|
||||
def forward(self, x, x_emb=None):
|
||||
args = self.args
|
||||
B, T, C = x.size()
|
||||
if self.layer_id == 0:
|
||||
x = self.ln0(x)
|
||||
if args.my_pos_emb > 0:
|
||||
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
|
||||
x = x + pos_emb
|
||||
|
||||
if self.layer_id == 0 and args.pre_ffn > 0:
|
||||
x = x + self.ffnPre(self.ln1(x))
|
||||
else:
|
||||
x = x + self.att(self.ln1(x))
|
||||
x = x + self.ffn(self.ln2(x))
|
||||
|
||||
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
||||
xx = self.tiny_ln(x)
|
||||
q = self.tiny_q(xx)[:, :T, :]
|
||||
k = self.tiny_k(xx)[:, :T, :]
|
||||
c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
|
||||
c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
|
||||
x = x + c @ self.tiny_v(x_emb)
|
||||
return x
|
||||
|
||||
|
||||
class L2Wrap(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, loss, y):
|
||||
ctx.save_for_backward(y)
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
y = ctx.saved_tensors[0]
|
||||
# to encourage the logits to be close to 0
|
||||
factor = 1e-4 / (y.shape[0] * y.shape[1])
|
||||
maxx, ids = torch.max(y, -1, keepdim=True)
|
||||
gy = torch.zeros_like(y)
|
||||
gy.scatter_(-1, ids, maxx * factor)
|
||||
return (grad_output, gy)
|
||||
|
||||
|
||||
class RWKV(pl.LightningModule):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
if not hasattr(args, 'dim_att'):
|
||||
args.dim_att = args.n_embd
|
||||
if not hasattr(args, 'dim_ffn'):
|
||||
args.dim_ffn = args.n_embd * 4
|
||||
if not hasattr(args, 'tiny_att_layer'):
|
||||
args.tiny_att_layer = -1
|
||||
if not hasattr(args, 'tiny_att_dim'):
|
||||
args.tiny_att_dim = -1
|
||||
|
||||
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
|
||||
|
||||
self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
|
||||
|
||||
self.ln_out = nn.LayerNorm(args.n_embd)
|
||||
self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
|
||||
|
||||
if args.head_qk > 0:
|
||||
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||
self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
||||
|
||||
def configure_optimizers(self):
|
||||
args = self.args
|
||||
if args.layerwise_lr > 0:
|
||||
lr_1x = set()
|
||||
lr_2x = set()
|
||||
lr_3x = set()
|
||||
for n, p in self.named_parameters():
|
||||
if "time_mix" in n:
|
||||
if args.my_pile_stage == 2:
|
||||
lr_2x.add(n)
|
||||
else:
|
||||
lr_1x.add(n)
|
||||
elif "time_decay" in n:
|
||||
if args.my_pile_stage == 2:
|
||||
lr_3x.add(n)
|
||||
else:
|
||||
lr_2x.add(n)
|
||||
elif "time_first" in n:
|
||||
lr_3x.add(n)
|
||||
else:
|
||||
lr_1x.add(n)
|
||||
lr_1x = sorted(list(lr_1x))
|
||||
lr_2x = sorted(list(lr_2x))
|
||||
lr_3x = sorted(list(lr_3x))
|
||||
# print('1x', lr_1x)
|
||||
# print('2x', lr_2x)
|
||||
# print('3x', lr_3x)
|
||||
param_dict = {n: p for n, p in self.named_parameters()}
|
||||
if args.my_pile_stage == 2:
|
||||
optim_groups = [
|
||||
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
||||
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
|
||||
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
|
||||
]
|
||||
else:
|
||||
optim_groups = [
|
||||
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
||||
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
|
||||
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
|
||||
]
|
||||
else:
|
||||
optim_groups = [
|
||||
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
for g in optim_groups:
|
||||
g["params"] = [p for p in g["params"] if p.requires_grad]
|
||||
optim_groups = [g for g in optim_groups if len(g["params"]) > 0]
|
||||
|
||||
if self.deepspeed_offload:
|
||||
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
|
||||
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
|
||||
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
|
||||
|
||||
@property
|
||||
def deepspeed_offload(self) -> bool:
|
||||
strategy = self.trainer.strategy
|
||||
if isinstance(strategy, DeepSpeedStrategy):
|
||||
cfg = strategy.config["zero_optimization"]
|
||||
return cfg.get("offload_optimizer") or cfg.get("offload_param")
|
||||
return False
|
||||
|
||||
def forward(self, idx):
|
||||
args = self.args
|
||||
B, T = idx.size()
|
||||
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
|
||||
|
||||
x = self.emb(idx)
|
||||
x_emb = x
|
||||
|
||||
if args.tiny_att_dim > 0:
|
||||
for block in self.blocks:
|
||||
if args.grad_cp == 1:
|
||||
if args.lora:
|
||||
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
|
||||
else:
|
||||
x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
|
||||
else:
|
||||
x = block(x, x_emb)
|
||||
else:
|
||||
for block in self.blocks:
|
||||
if args.grad_cp == 1:
|
||||
if args.lora:
|
||||
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
|
||||
else:
|
||||
x = deepspeed.checkpointing.checkpoint(block, x)
|
||||
else:
|
||||
x = block(x)
|
||||
|
||||
x = self.ln_out(x)
|
||||
|
||||
if args.head_qk > 0:
|
||||
q = self.head_q(x)[:, :T, :]
|
||||
k = self.head_k(x)[:, :T, :]
|
||||
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
|
||||
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
|
||||
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
c = c @ F.one_hot(idx, num_classes=args.vocab_size)
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||
c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
|
||||
|
||||
x = self.head(x) + c
|
||||
else:
|
||||
x = self.head(x)
|
||||
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
args = self.args
|
||||
if args.my_qa_mask != 1:
|
||||
idx, targets = batch
|
||||
logits = self(idx)
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
else:
|
||||
idx, targets, mask = batch
|
||||
mask = mask.view(-1)
|
||||
sum_mask = torch.sum(mask).item()
|
||||
# if sum_mask == 0:
|
||||
# return torch.tensor([0.0], requires_grad=True)
|
||||
|
||||
logits = self(idx)
|
||||
if sum_mask == mask.shape[0]:
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
# print('rank', self.global_rank, 'loss', loss.item())
|
||||
else:
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
|
||||
# loss_raw = loss
|
||||
loss = torch.sum(loss * mask) / sum_mask
|
||||
|
||||
# torch.set_printoptions(threshold=10000)
|
||||
# if True: #self.global_rank == 1:
|
||||
# tmp = ''
|
||||
# sss = 0
|
||||
# ccc = 0
|
||||
# for i in range(mask.shape[0]):
|
||||
# if mask[i] > 0:
|
||||
# tmp += str(idx.view(-1)[i].item()) + ','
|
||||
# sss += loss_raw.view(-1)[i].float().item()
|
||||
# ccc += 1
|
||||
# print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
|
||||
|
||||
return L2Wrap.apply(loss, logits)
|
||||
|
||||
def training_step_end(self, batch_parts):
|
||||
all = self.all_gather(batch_parts)
|
||||
if self.trainer.is_global_zero:
|
||||
self.trainer.my_loss_all = all
|
||||
|
||||
def generate_init_weight(self):
|
||||
print(
|
||||
f"""
|
||||
############################################################################
|
||||
#
|
||||
# Init model weight (slow for large models)...
|
||||
#
|
||||
############################################################################
|
||||
"""
|
||||
)
|
||||
m = {}
|
||||
for n in self.state_dict():
|
||||
p = self.state_dict()[n]
|
||||
shape = p.shape
|
||||
|
||||
gain = 1.0
|
||||
scale = 1.0
|
||||
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
|
||||
m[n] = p
|
||||
else:
|
||||
if n == "emb.weight":
|
||||
scale = -1 * self.args.lr_init
|
||||
else:
|
||||
if shape[0] > shape[1]:
|
||||
gain = math.sqrt(shape[0] / shape[1])
|
||||
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
|
||||
if kk in n:
|
||||
scale = 0
|
||||
if n == "head.weight":
|
||||
scale = 0.5
|
||||
if "head_k." in n:
|
||||
scale = 0.1
|
||||
if "head_q." in n:
|
||||
scale = 0
|
||||
|
||||
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
|
||||
|
||||
if self.args.accelerator.upper() == "GPU":
|
||||
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
|
||||
else:
|
||||
m[n] = torch.empty((shape[0], shape[1]))
|
||||
|
||||
if scale == 0:
|
||||
nn.init.zeros_(m[n])
|
||||
elif scale < 0:
|
||||
nn.init.uniform_(m[n], a=scale, b=-scale)
|
||||
else:
|
||||
nn.init.orthogonal_(m[n], gain=gain * scale)
|
||||
|
||||
m[n] = m[n].cpu()
|
||||
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||
m[n] = m[n].half()
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
m[n] = m[n].bfloat16()
|
||||
|
||||
# if n == "emb.weight":
|
||||
# print(m[n])
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return m
|
||||
203
finetune/lora/src/trainer.py
vendored
Normal file
203
finetune/lora/src/trainer.py
vendored
Normal file
@@ -0,0 +1,203 @@
|
||||
import os, math, time, datetime, subprocess
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
from .model import LORA_CONFIG
|
||||
|
||||
def my_save(dd, ff):
|
||||
if '14b-run1' not in ff:
|
||||
torch.save(dd, ff)
|
||||
else:
|
||||
fn = ff.split('/')[-1]
|
||||
fff = '/dev/shm/' + fn
|
||||
torch.save(dd, fff)
|
||||
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
|
||||
|
||||
class train_callback(pl.Callback):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
||||
args = self.args
|
||||
# if args.cuda_cleanup > 0:
|
||||
# torch.cuda.empty_cache()
|
||||
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
||||
|
||||
# LR schedule
|
||||
w_step = args.warmup_steps
|
||||
if args.lr_final == args.lr_init or args.epoch_count == 0:
|
||||
lr = args.lr_init
|
||||
else:
|
||||
decay_step = real_step - args.my_pile_edecay * args.epoch_steps
|
||||
decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
|
||||
progress = (decay_step - w_step + 1) / (decay_total - w_step)
|
||||
progress = min(1, max(0, progress))
|
||||
|
||||
if args.lr_final == 0 or args.lr_init == 0: # linear decay
|
||||
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
|
||||
else: # exp decay
|
||||
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
|
||||
|
||||
if trainer.global_step < w_step:
|
||||
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
|
||||
# if trainer.is_global_zero:
|
||||
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
|
||||
|
||||
for param_group in trainer.optimizers[0].param_groups:
|
||||
if args.layerwise_lr > 0:
|
||||
param_group["lr"] = lr * param_group["my_lr_scale"]
|
||||
# print(param_group["lr"], param_group["my_lr_scale"])
|
||||
else:
|
||||
param_group["lr"] = lr
|
||||
|
||||
trainer.my_lr = lr
|
||||
# rank_zero_info(f"{real_step} {lr}")
|
||||
|
||||
if trainer.global_step == 0:
|
||||
if trainer.is_global_zero: # logging
|
||||
trainer.my_loss_sum = 0
|
||||
trainer.my_loss_count = 0
|
||||
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
|
||||
trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
|
||||
try:
|
||||
print(f"\n{trainer.strategy.config}\n")
|
||||
trainer.my_log.write(f"{trainer.strategy.config}\n")
|
||||
except:
|
||||
pass
|
||||
trainer.my_log.flush()
|
||||
if len(args.wandb) > 0:
|
||||
print("Login to wandb...")
|
||||
import wandb
|
||||
wandb.init(
|
||||
project=args.wandb,
|
||||
name=args.run_name + " " + args.my_timestamp,
|
||||
config=args,
|
||||
save_code=False,
|
||||
)
|
||||
trainer.my_wandb = wandb
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
args = self.args
|
||||
if trainer.is_global_zero: # logging
|
||||
t_now = time.time_ns()
|
||||
token_per_step = args.ctx_len * args.real_bsz
|
||||
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
||||
kt_s = 0
|
||||
try:
|
||||
t_cost = (t_now - trainer.my_time_ns) / 1e9
|
||||
kt_s = token_per_step / t_cost / 1000
|
||||
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
|
||||
self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
|
||||
except:
|
||||
pass
|
||||
trainer.my_time_ns = t_now
|
||||
trainer.my_loss = trainer.my_loss_all.float().mean().item()
|
||||
trainer.my_loss_sum += trainer.my_loss
|
||||
trainer.my_loss_count += 1
|
||||
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
|
||||
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
|
||||
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
|
||||
# self.log("s", real_step, prog_bar=True, on_step=True)
|
||||
|
||||
if len(args.wandb) > 0:
|
||||
lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
|
||||
if kt_s > 0:
|
||||
lll["kt/s"] = kt_s
|
||||
trainer.my_wandb.log(lll, step=int(real_step))
|
||||
if args.magic_prime > 0:
|
||||
expand_factor = 2 if args.my_qa_mask > 0 else 1
|
||||
if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1:
|
||||
to_save_dict = pl_module.state_dict()
|
||||
my_save(
|
||||
to_save_dict,
|
||||
f"{args.proj_dir}/rwkv-final.pth",
|
||||
)
|
||||
|
||||
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
args = self.args
|
||||
dataset = trainer.train_dataloader.dataset.datasets
|
||||
assert "MyDataset" in str(dataset)
|
||||
dataset.global_rank = trainer.global_rank
|
||||
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
|
||||
dataset.world_size = trainer.world_size
|
||||
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
args = self.args
|
||||
if trainer.is_global_zero: # logging & save state_dict
|
||||
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
|
||||
if args.data_type == 'wds_img':
|
||||
raw_dict = pl_module.state_dict()
|
||||
to_save_dict = {}
|
||||
for k in raw_dict:
|
||||
if k.startswith('encoder.') or k.startswith('decoder.'):
|
||||
to_save_dict[k] = raw_dict[k]
|
||||
else:
|
||||
to_save_dict = pl_module.state_dict()
|
||||
|
||||
if args.lora:
|
||||
enable_time_finetune = 'time' in LORA_CONFIG["parts"]
|
||||
enable_ln_finetune = 'ln' in LORA_CONFIG["parts"]
|
||||
lora_dict = {}
|
||||
for name, state in to_save_dict.items():
|
||||
if ('.lora_' in name
|
||||
or (enable_time_finetune and '.time_' in name)
|
||||
or (enable_ln_finetune and '.ln' in name)):
|
||||
lora_dict[name] = state
|
||||
to_save_dict = lora_dict
|
||||
|
||||
try:
|
||||
my_save(
|
||||
to_save_dict,
|
||||
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
|
||||
)
|
||||
except Exception as e:
|
||||
print('Error\n\n', e, '\n\n')
|
||||
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
|
||||
trainer.my_log.flush()
|
||||
|
||||
trainer.my_loss_sum = 0
|
||||
trainer.my_loss_count = 0
|
||||
|
||||
|
||||
@rank_zero_only
|
||||
def generate_init_weight(model, init_weight_name):
|
||||
mm = model.generate_init_weight()
|
||||
|
||||
if model.args.my_pile_stage == 1:
|
||||
if len(model.args.load_model) > 0:
|
||||
print(f"Combine weights from {model.args.load_model}...")
|
||||
load_dict = torch.load(model.args.load_model, map_location="cpu")
|
||||
for k in load_dict:
|
||||
assert k in mm
|
||||
src = load_dict[k]
|
||||
try:
|
||||
mm[k] = src.reshape(mm[k].shape)
|
||||
except:
|
||||
tmp = mm[k].squeeze().clone()
|
||||
print(k, src.shape, '-->', mm[k].shape)
|
||||
ss = src.shape[0]
|
||||
dd = tmp.shape[0]
|
||||
for i in range(dd):
|
||||
pos = i / dd * ss
|
||||
if pos >= ss - 1:
|
||||
tmp[i] = src[ss-1]
|
||||
else:
|
||||
p0 = int(math.floor(pos))
|
||||
ii = pos - p0
|
||||
tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
|
||||
mm[k] = tmp.reshape(mm[k].shape)
|
||||
sss = src.squeeze().float().cpu().numpy()
|
||||
print(sss[:10], '...', sss[-10:])
|
||||
mmm = mm[k].squeeze().float().cpu().numpy()
|
||||
print(mmm[:10], '...', mmm[-10:])
|
||||
|
||||
print(f"Save to {init_weight_name}...")
|
||||
torch.save(mm, init_weight_name)
|
||||
|
||||
if model.args.my_pile_stage == 1:
|
||||
print("Done. Now go for stage 2.")
|
||||
exit(0)
|
||||
130
finetune/lora/src/utils.py
vendored
Normal file
130
finetune/lora/src/utils.py
vendored
Normal file
@@ -0,0 +1,130 @@
|
||||
import json, time, random, os
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
time_slot = {}
|
||||
time_ref = time.time_ns()
|
||||
|
||||
def record_time(name):
|
||||
if name not in time_slot:
|
||||
time_slot[name] = 1e20
|
||||
tt = (time.time_ns() - time_ref) / 1e9
|
||||
if tt < time_slot[name]:
|
||||
time_slot[name] = tt
|
||||
|
||||
class TOKENIZER():
|
||||
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
|
||||
if 'list' in str(type(WORD_NAME)):
|
||||
self.charMode = False
|
||||
if WORD_NAME[0] == WORD_NAME[1]:
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
|
||||
else:
|
||||
from transformers import GPT2TokenizerFast
|
||||
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
|
||||
self.vocab_size = len(self.tokenizer)
|
||||
else:
|
||||
self.charMode = True
|
||||
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
|
||||
self.word_table = json.load(result_file)
|
||||
|
||||
self.vocab_size = len(self.word_table)
|
||||
|
||||
self.stoi = {v: int(k) for k, v in self.word_table.items()}
|
||||
self.itos = {int(k): v for k, v in self.word_table.items()}
|
||||
|
||||
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
|
||||
|
||||
def refine_context(self, context):
|
||||
context = context.strip().split('\n')
|
||||
for c in range(len(context)):
|
||||
context[c] = context[c].strip().strip('\u3000').strip('\r')
|
||||
context = list(filter(lambda c: c != '', context))
|
||||
context = '\n' + ('\n'.join(context)).strip()
|
||||
if context == '':
|
||||
context = '\n'
|
||||
return context
|
||||
|
||||
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
|
||||
# out[self.UNKNOWN_CHAR] = -float('Inf')
|
||||
lastChar = int(x[-1])
|
||||
|
||||
probs = F.softmax(out, dim=-1)
|
||||
|
||||
if self.charMode:
|
||||
if self.itos[lastChar] == '\n':
|
||||
top_p = top_p_newline
|
||||
else:
|
||||
top_p = top_p_usual
|
||||
else:
|
||||
top_p = top_p_usual
|
||||
|
||||
if os.environ["RWKV_RUN_DEVICE"] == "cpu":
|
||||
probs = probs.numpy()
|
||||
sorted_probs = np.sort(probs)[::-1]
|
||||
cumulative_probs = np.cumsum(sorted_probs)
|
||||
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||
probs[probs < cutoff] = 0
|
||||
if temperature != 1.0:
|
||||
probs = probs.pow(1.0 / temperature)
|
||||
probs = probs / np.sum(probs)
|
||||
out = np.random.choice(a=len(probs), p=probs)
|
||||
return out
|
||||
else:
|
||||
sorted_probs = torch.sort(probs, descending=True)[0]
|
||||
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
|
||||
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||
probs[probs < cutoff] = 0
|
||||
if temperature != 1.0:
|
||||
probs = probs.pow(1.0 / temperature)
|
||||
out = torch.multinomial(probs, num_samples=1)[0]
|
||||
return out
|
||||
|
||||
def MaybeIsPrime(number):
|
||||
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def FermatPrimalityTest(number):
|
||||
if number > 1:
|
||||
for time in range(3):
|
||||
randomNumber = random.randint(2, number) - 1
|
||||
if pow(randomNumber, number - 1, number) != 1:
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def MillerRabinPrimalityTest(number):
|
||||
if number == 2:
|
||||
return True
|
||||
elif number == 1 or number % 2 == 0:
|
||||
return False
|
||||
oddPartOfNumber = number - 1
|
||||
timesTwoDividNumber = 0
|
||||
while oddPartOfNumber % 2 == 0:
|
||||
oddPartOfNumber = oddPartOfNumber // 2
|
||||
timesTwoDividNumber = timesTwoDividNumber + 1
|
||||
|
||||
for time in range(3):
|
||||
while True:
|
||||
randomNumber = random.randint(2, number) - 1
|
||||
if randomNumber != 0 and randomNumber != 1:
|
||||
break
|
||||
|
||||
randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
|
||||
|
||||
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
|
||||
iterationNumber = 1
|
||||
|
||||
while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
|
||||
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
|
||||
iterationNumber = iterationNumber + 1
|
||||
if randomNumberWithPower != (number - 1):
|
||||
return False
|
||||
|
||||
return True
|
||||
388
finetune/lora/train.py
vendored
Normal file
388
finetune/lora/train.py
vendored
Normal file
@@ -0,0 +1,388 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
|
||||
if __name__ == "__main__":
|
||||
from argparse import ArgumentParser
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
|
||||
rank_zero_info("########## work in progress ##########")
|
||||
|
||||
########################################################################################################
|
||||
#
|
||||
# example: train a simple L12-D768 RWKV on dummy data
|
||||
#
|
||||
# python train.py --load_model "" --wandb "" --proj_dir "out" \
|
||||
# --data_file "" --data_type "dummy" --vocab_size 0 \
|
||||
# --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
|
||||
# --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
|
||||
# --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
|
||||
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
|
||||
|
||||
# example: train a simple L6-D512 RWKV from scratch on enwik8
|
||||
#
|
||||
# python train.py --load_model "" --wandb "" --proj_dir "out" \
|
||||
# --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
|
||||
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
|
||||
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
|
||||
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
|
||||
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
|
||||
|
||||
# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
|
||||
#
|
||||
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
|
||||
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
|
||||
# --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
|
||||
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
|
||||
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
|
||||
# --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
|
||||
|
||||
# example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
|
||||
#
|
||||
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
|
||||
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
|
||||
# --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
|
||||
# --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
|
||||
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
|
||||
# --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
|
||||
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
|
||||
parser.add_argument("--proj_dir", default="out", type=str)
|
||||
parser.add_argument("--random_seed", default="-1", type=int)
|
||||
|
||||
parser.add_argument("--data_file", default="", type=str)
|
||||
parser.add_argument("--data_type", default="utf-8", type=str)
|
||||
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
|
||||
|
||||
parser.add_argument("--ctx_len", default=1024, type=int)
|
||||
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
|
||||
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
|
||||
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
|
||||
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
|
||||
|
||||
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
|
||||
parser.add_argument("--n_layer", default=6, type=int)
|
||||
parser.add_argument("--n_embd", default=512, type=int)
|
||||
parser.add_argument("--dim_att", default=0, type=int)
|
||||
parser.add_argument("--dim_ffn", default=0, type=int)
|
||||
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
|
||||
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
|
||||
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
|
||||
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
|
||||
|
||||
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
|
||||
parser.add_argument("--lr_final", default=1e-5, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
|
||||
parser.add_argument("--beta1", default=0.9, type=float)
|
||||
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
|
||||
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
||||
|
||||
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
|
||||
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
|
||||
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
|
||||
parser.add_argument("--my_pile_edecay", default=0, type=int)
|
||||
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
|
||||
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
|
||||
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
|
||||
|
||||
parser.add_argument("--my_img_version", default=0, type=str)
|
||||
parser.add_argument("--my_img_size", default=0, type=int)
|
||||
parser.add_argument("--my_img_bit", default=0, type=int)
|
||||
parser.add_argument("--my_img_clip", default='x', type=str)
|
||||
parser.add_argument("--my_img_clip_scale", default=1, type=float)
|
||||
parser.add_argument("--my_img_l1_scale", default=0, type=float)
|
||||
parser.add_argument("--my_img_encoder", default='x', type=str)
|
||||
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
|
||||
parser.add_argument("--my_sample_len", default=0, type=int)
|
||||
parser.add_argument("--my_ffn_shift", default=1, type=int)
|
||||
parser.add_argument("--my_att_shift", default=1, type=int)
|
||||
parser.add_argument("--my_pos_emb", default=0, type=int)
|
||||
parser.add_argument("--load_partial", default=0, type=int)
|
||||
parser.add_argument("--magic_prime", default=0, type=int)
|
||||
parser.add_argument("--my_qa_mask", default=0, type=int)
|
||||
parser.add_argument("--my_testing", default='', type=str)
|
||||
|
||||
parser.add_argument("--lora", action="store_true")
|
||||
parser.add_argument("--lora_load", default="", type=str)
|
||||
parser.add_argument("--lora_r", default=8, type=int)
|
||||
parser.add_argument("--lora_alpha", default=32, type=float)
|
||||
parser.add_argument("--lora_dropout", default=0.01, type=float)
|
||||
parser.add_argument("--lora_parts", default="att,ln,time", type=str)
|
||||
|
||||
parser = Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
########################################################################################################
|
||||
|
||||
import os, warnings, math, datetime, sys, time, importlib
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
if "deepspeed" in args.strategy:
|
||||
import deepspeed
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
if args.random_seed >= 0:
|
||||
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
|
||||
seed_everything(args.random_seed)
|
||||
|
||||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
|
||||
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
|
||||
# os.environ["WDS_SHOW_SEED"] = "1"
|
||||
|
||||
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
args.enable_checkpointing = False
|
||||
args.replace_sampler_ddp = False
|
||||
args.logger = False
|
||||
args.gradient_clip_val = 1.0
|
||||
args.num_sanity_val_steps = 0
|
||||
args.check_val_every_n_epoch = int(1e20)
|
||||
args.log_every_n_steps = int(1e20)
|
||||
args.max_epochs = -1 # continue forever
|
||||
args.betas = (args.beta1, args.beta2)
|
||||
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
|
||||
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
|
||||
os.environ["RWKV_MY_TESTING"] = args.my_testing
|
||||
if args.dim_att <= 0:
|
||||
args.dim_att = args.n_embd
|
||||
if args.dim_ffn <= 0:
|
||||
args.dim_ffn = args.n_embd * 4
|
||||
|
||||
if args.data_type == "wds_img":
|
||||
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
|
||||
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
|
||||
else:
|
||||
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
|
||||
if not os.path.exists(args.proj_dir):
|
||||
os.makedirs(args.proj_dir)
|
||||
|
||||
if args.my_pile_stage > 0:
|
||||
magic_prime_bak = args.magic_prime
|
||||
if args.ctx_len == 1024:
|
||||
args.magic_prime = 324331313
|
||||
args.epoch_count = 8043
|
||||
elif args.ctx_len == 2048:
|
||||
args.magic_prime = 162165671
|
||||
args.epoch_count = 4021
|
||||
elif args.ctx_len == 4096:
|
||||
args.magic_prime = 81082817
|
||||
args.epoch_count = 2010
|
||||
if args.my_pile_shift < 0:
|
||||
if args.ctx_len == 1024:
|
||||
args.my_pile_shift = 0
|
||||
elif args.ctx_len == 2048:
|
||||
args.my_pile_shift = 512
|
||||
elif args.ctx_len == 4096:
|
||||
args.my_pile_shift = 768
|
||||
|
||||
if magic_prime_bak > 0:
|
||||
args.magic_prime = magic_prime_bak
|
||||
|
||||
args.epoch_steps = 40320 // args.real_bsz
|
||||
assert args.epoch_steps * args.real_bsz == 40320
|
||||
if args.my_pile_stage == 2:
|
||||
assert args.lr_final == args.lr_init
|
||||
if args.my_pile_stage >= 2: # find latest saved model
|
||||
list_p = []
|
||||
for p in os.listdir(args.proj_dir):
|
||||
if p.startswith("rwkv") and p.endswith(".pth"):
|
||||
p = ((p.split("-"))[1].split("."))[0]
|
||||
if p == "init":
|
||||
p = -1
|
||||
else:
|
||||
p = int(p)
|
||||
list_p += [p]
|
||||
list_p.sort()
|
||||
max_p = list_p[-1]
|
||||
if len(list_p) > 1:
|
||||
args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted
|
||||
if max_p == -1:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
||||
else:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
||||
if args.my_pile_stage == 2:
|
||||
args.warmup_steps = 10
|
||||
else:
|
||||
args.warmup_steps = 30
|
||||
args.epoch_begin = max_p + 1
|
||||
|
||||
samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||
tokens_per_epoch = samples_per_epoch * args.ctx_len
|
||||
rank_zero_info(
|
||||
f"""
|
||||
############################################################################
|
||||
#
|
||||
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
|
||||
#
|
||||
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
|
||||
#
|
||||
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
|
||||
#
|
||||
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
|
||||
#
|
||||
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
|
||||
# LoRA = {f'enabled, {args.lora_r} r, {args.lora_alpha} alpha, {args.lora_dropout} dropout, on {args.lora_parts}' if args.lora else 'disabled'}
|
||||
#
|
||||
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
|
||||
#
|
||||
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
|
||||
# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
|
||||
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
|
||||
#
|
||||
############################################################################
|
||||
"""
|
||||
)
|
||||
rank_zero_info(str(vars(args)) + "\n")
|
||||
|
||||
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
|
||||
|
||||
if args.lr_final == 0 or args.lr_init == 0:
|
||||
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
|
||||
|
||||
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
|
||||
os.environ["RWKV_FLOAT_MODE"] = args.precision
|
||||
if args.precision == "fp32":
|
||||
for i in range(10):
|
||||
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
|
||||
if args.precision == "fp16":
|
||||
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
|
||||
|
||||
os.environ["RWKV_JIT_ON"] = "1"
|
||||
if "deepspeed_stage_3" in args.strategy:
|
||||
os.environ["RWKV_JIT_ON"] = "0"
|
||||
if args.lora and args.grad_cp == 1:
|
||||
print('!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it')
|
||||
os.environ["RWKV_JIT_ON"] = "0"
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
if args.precision == "fp32":
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
else:
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if "32" in args.precision:
|
||||
args.precision = 32
|
||||
elif args.precision == "fp16":
|
||||
args.precision = 16
|
||||
else:
|
||||
args.precision = "bf16"
|
||||
|
||||
########################################################################################################
|
||||
|
||||
from src.trainer import train_callback, generate_init_weight
|
||||
from src.dataset import MyDataset
|
||||
|
||||
train_data = MyDataset(args)
|
||||
args.vocab_size = train_data.vocab_size
|
||||
|
||||
if args.data_type == 'wds_img':
|
||||
from src.model_img import RWKV_IMG
|
||||
assert args.lora, "LoRA not yet supported for RWKV_IMG"
|
||||
model = RWKV_IMG(args)
|
||||
else:
|
||||
from src.model import RWKV, LORA_CONFIG, LoraLinear
|
||||
if args.lora:
|
||||
assert args.lora_r > 0, "LoRA should have its `r` > 0"
|
||||
LORA_CONFIG["r"] = args.lora_r
|
||||
LORA_CONFIG["alpha"] = args.lora_alpha
|
||||
LORA_CONFIG["dropout"] = args.lora_dropout
|
||||
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(','))
|
||||
enable_time_finetune = 'time' in LORA_CONFIG["parts"]
|
||||
enable_ln_finetune = 'ln' in LORA_CONFIG["parts"]
|
||||
model = RWKV(args)
|
||||
# only train lora parameters
|
||||
if args.lora:
|
||||
model.requires_grad_(False)
|
||||
for name, module in model.named_modules():
|
||||
# have to check param name since it may have been wrapped by torchscript
|
||||
if any(n.startswith("lora_") for n, _ in module.named_parameters()):
|
||||
print(f' LoRA training module {name}')
|
||||
for pname, param in module.named_parameters():
|
||||
param.requires_grad = 'lora_' in pname
|
||||
elif enable_ln_finetune and '.ln' in name:
|
||||
print(f' LoRA additionally training module {name}')
|
||||
for param in module.parameters():
|
||||
param.requires_grad = True
|
||||
elif enable_time_finetune and any(n.startswith("time") for n, _ in module.named_parameters()):
|
||||
for pname, param in module.named_parameters():
|
||||
if pname.startswith("time"):
|
||||
print(f' LoRA additionally training parameter {pname}')
|
||||
param.requires_grad = True
|
||||
|
||||
if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
|
||||
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
|
||||
generate_init_weight(model, init_weight_name) # save initial weights
|
||||
args.load_model = init_weight_name
|
||||
|
||||
rank_zero_info(f"########## Loading {args.load_model}... ##########")
|
||||
try:
|
||||
load_dict = torch.load(args.load_model, map_location="cpu")
|
||||
except:
|
||||
rank_zero_info(f"Bad checkpoint {args.load_model}")
|
||||
if args.my_pile_stage >= 2: # try again using another checkpoint
|
||||
max_p = args.my_pile_prev_p
|
||||
if max_p == -1:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
||||
else:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
||||
args.epoch_begin = max_p + 1
|
||||
rank_zero_info(f"Trying {args.load_model}")
|
||||
load_dict = torch.load(args.load_model, map_location="cpu")
|
||||
|
||||
if args.load_partial == 1:
|
||||
load_keys = load_dict.keys()
|
||||
for k in model.state_dict():
|
||||
if k not in load_keys:
|
||||
load_dict[k] = model.state_dict()[k]
|
||||
# If using LoRA, the LoRA keys might be missing in the original model
|
||||
model.load_state_dict(load_dict, strict=(not args.lora))
|
||||
if os.path.isfile(args.lora_load):
|
||||
model.load_state_dict(torch.load(args.lora_load, map_location="cpu"),
|
||||
strict=False)
|
||||
|
||||
trainer: Trainer = Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[train_callback(args)],
|
||||
)
|
||||
|
||||
if (args.lr_init > 1e-4 or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8):
|
||||
if 'I_KNOW_WHAT_IM_DOING' in os.environ:
|
||||
if trainer.global_rank == 0:
|
||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
||||
print(f' WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
|
||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
||||
else:
|
||||
if trainer.global_rank == 0:
|
||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
||||
print(f' ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
|
||||
print(f' Unless you are sure this is what you want, adjust them accordingly')
|
||||
print(f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")')
|
||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
||||
exit(0)
|
||||
|
||||
if trainer.global_rank == 0:
|
||||
for n in model.state_dict():
|
||||
shape = model.state_dict()[n].shape
|
||||
shape = [i for i in shape if i != 1]
|
||||
if len(shape) > 1:
|
||||
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
|
||||
else:
|
||||
print(f"{str(shape[0]).ljust(5)} {n}")
|
||||
|
||||
if "deepspeed" in args.strategy:
|
||||
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
|
||||
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
|
||||
|
||||
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
||||
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
|
||||
|
||||
trainer.fit(model, data_loader)
|
||||
3
finetune/requirements.txt
Normal file
3
finetune/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
torch==1.13.1
|
||||
pytorch_lightning==1.9.5
|
||||
deepspeed
|
||||
Reference in New Issue
Block a user