From 333619839aed0d9e3b921ea30b926cd126541238 Mon Sep 17 00:00:00 2001 From: josc146 Date: Wed, 13 Mar 2024 17:51:53 +0800 Subject: [PATCH] rwkv6 lora finetune support (https://github.com/JL-er/RWKV-LORA) --- finetune/get_layer_and_embd.py | 6 +- finetune/install-wsl-dep-and-train.sh | 2 +- finetune/lora/v6/cuda/wkv5_cuda.cu | 202 +++++ finetune/lora/v6/cuda/wkv5_op.cpp | 22 + finetune/lora/v6/cuda/wkv6_cuda.cu | 242 ++++++ finetune/lora/v6/cuda/wkv6_op.cpp | 22 + finetune/lora/v6/src/__init__.py | 0 finetune/lora/v6/src/binidx.py | 303 +++++++ finetune/lora/v6/src/dataset.py | 242 ++++++ finetune/lora/v6/src/model.py | 1086 +++++++++++++++++++++++++ finetune/lora/v6/src/trainer.py | 310 +++++++ finetune/lora/v6/src/utils.py | 139 ++++ finetune/lora/v6/train.py | 435 ++++++++++ 13 files changed, 3009 insertions(+), 2 deletions(-) create mode 100644 finetune/lora/v6/cuda/wkv5_cuda.cu create mode 100644 finetune/lora/v6/cuda/wkv5_op.cpp create mode 100644 finetune/lora/v6/cuda/wkv6_cuda.cu create mode 100644 finetune/lora/v6/cuda/wkv6_op.cpp create mode 100644 finetune/lora/v6/src/__init__.py create mode 100644 finetune/lora/v6/src/binidx.py create mode 100644 finetune/lora/v6/src/dataset.py create mode 100644 finetune/lora/v6/src/model.py create mode 100644 finetune/lora/v6/src/trainer.py create mode 100644 finetune/lora/v6/src/utils.py create mode 100644 finetune/lora/v6/train.py diff --git a/finetune/get_layer_and_embd.py b/finetune/get_layer_and_embd.py index 04e501a..7ac268d 100644 --- a/finetune/get_layer_and_embd.py +++ b/finetune/get_layer_and_embd.py @@ -52,9 +52,13 @@ for x in keys: if "time_maa" in x: version = max(6, version) +params = f"--vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}" + if version <= expected_max_version: + if version == 6: + params += ' --my_testing "x060"' print( - f"v{int(version)}/train.py --vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}", + f"v{int(version)}/train.py {params}", end="", ) else: diff 
--git a/finetune/install-wsl-dep-and-train.sh b/finetune/install-wsl-dep-and-train.sh index f1d0d0a..7281c4b 100644 --- a/finetune/install-wsl-dep-and-train.sh +++ b/finetune/install-wsl-dep-and-train.sh @@ -53,7 +53,7 @@ else fi echo "loading $loadModel" -modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 5.2) +modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 6.0) echo $modelInfo if [[ $modelInfo =~ "--n_layer" ]]; then sudo rm -rf /root/.cache/torch_extensions diff --git a/finetune/lora/v6/cuda/wkv5_cuda.cu b/finetune/lora/v6/cuda/wkv5_cuda.cu new file mode 100644 index 0000000..3e6b859 --- /dev/null +++ b/finetune/lora/v6/cuda/wkv5_cuda.cu @@ -0,0 +1,202 @@ +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +template +__global__ void kernel_forward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _w += h*_N_; + _u += h*_N_; + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + float state[_N_] = {0}; + + __syncthreads(); + w[i] = _w[i]; + u[i] = float(_u[i]); + __syncthreads(); + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + 
s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } +} + +template +__global__ void kernel_backward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, + F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _w += h*_N_; + _u += h*_N_; + __w += h*_N_; + + __shared__ float w_[_N_], u_[_N_]; + __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_]; + __syncthreads(); + w_[i] = _w[i]; + u_[i] = float(_u[i]); + __syncthreads(); + + const float w = w_[i]; + const float ww = __w[i]; + const float u = u_[i]; + + float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; + + float gw = 0, gu = 0; + const int t000 = b*T*C + h*_N_ + i; + const int t111 = (b+1)*T*C + h*_N_ + i; + const int t222 = t111 - 2*C; + + for (int t = t000; t < t111; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float k = float(_k[t]); + float gr = 0, gu_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = state[j]; + float x = k * v[j]; + + gr += (u * x + s) * gy[j]; + gu_ += x * gy[j]; + s = s * w + x; + } + _gr[t] = F(gr); + gu += float(_r[t]) * gu_; + } + _gu[b*C + h*_N_ + i] = F(gu); + + for (int t = t000; t < t222; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t + 2*C]); + __syncthreads(); + + const float k = float(_k[t]); + float gw_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + float& s2 = sbbbb[j]; + float x = k * v[j]; + + float tmp = w * (x + 
s); + s = tmp; + s2 = tmp + w * s2; + gw_ += s2 * gy[j]; + } + gw += float(_r[t + 2*C]) * gw_; + } + _gw[b*C + h*_N_ + i] = F(ww * gw); + + for (int t = t111 - C; t >= t000; t -= C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float rr = float(_r[t]); + float gk = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = rr * gy[j]; + + gk += (u * x + s) * v[j]; + s = x + s * w; + } + _gk[t] = F(gk); + } + + for (int t = t111 - C; t >= t000; t -= C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float gyy = float(_gy[t]); + float gv = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = sdddd[j]; + float x = gyy * r[j]; + + gv += (u_[j] * x + s) * k[j]; + s = x + s * w_[j]; + } + _gv[t] = F(gv); + } +} + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); +} + +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); +} diff --git a/finetune/lora/v6/cuda/wkv5_op.cpp b/finetune/lora/v6/cuda/wkv5_op.cpp new file mode 100644 index 0000000..4c9ece1 --- /dev/null +++ b/finetune/lora/v6/cuda/wkv5_op.cpp @@ -0,0 +1,22 @@ +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); + +void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, 
torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { + cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv5 forward"); + m.def("backward", &backward, "wkv5 backward"); +} + +TORCH_LIBRARY(wkv5, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/finetune/lora/v6/cuda/wkv6_cuda.cu b/finetune/lora/v6/cuda/wkv6_cuda.cu new file mode 100644 index 0000000..7b7c836 --- /dev/null +++ b/finetune/lora/v6/cuda/wkv6_cuda.cu @@ -0,0 +1,242 @@ +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +template +__global__ void kernel_forward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + float state[_N_] = {0}; + + __syncthreads(); + u[i] = float(_u[i]); + __syncthreads(); + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + w[i] = exp(_w[t]); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) 
+ { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } +} + +template +__global__ void kernel_backward_111(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, + F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + + __shared__ float u_[_N_]; + __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; + __syncthreads(); + u_[i] = float(_u[i]); + __syncthreads(); + + const float u = u_[i]; + + float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; + + const int t_0 = b*T*C + h*_N_ + i; + const int t_T_1 = t_0 + (T-1)*C; + const int t_T = t_0 + T*C; + + float gu = 0; + for (int t = t_0; t < t_T; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float k = float(_k[t]); + const float w = exp(_w[t]); + float gr = 0, gu_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = state[j]; + float x = k * v[j]; + + gr += (u * x + s) * gy[j]; + gu_ += x * gy[j]; + s = s * w + x; + } + _gr[t] = F(gr); + gu += float(_r[t]) * gu_; + } + _gu[b*C + h*_N_ + i] = F(gu); + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = 
float(_gy[t]); + __syncthreads(); + + const float rr = float(_r[t]); + const float w = exp(_w[t]); + float gk = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = rr * gy[j]; + + gk += (u * x + s) * v[j]; + s = x + s * w; + } + _gk[t] = F(gk); + } + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + w_[i] = exp(_w[t]); + __syncthreads(); + + const float gyy = float(_gy[t]); + float gv = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = sdddd[j]; + float x = gyy * r[j]; + + gv += (u_[j] * x + s) * k[j]; + s = x + s * w_[j]; + } + _gv[t] = F(gv); + } +} + +template +__global__ void kernel_backward_222(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, + F *__restrict__ const _gw) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + + __shared__ float v[_N_], gy[_N_]; + float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0}; + + const int t_0 = b*T*C + h*_N_ + i; + const int t_1 = t_0 + C; + const int t_2 = t_0 + 2*C; + const int t_T_1 = t_0 + (T-1)*C; + + for (int t = t_T_1; t > t_1; t -= C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float r = float(_r[t]); + const float w = exp(_w[t-C]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + float x = r * gy[j]; + s = (s + x) * w; + sum += s * v[j]; + } + sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]); + } + + float sss = sbbbb[0]; + _gw[t_0] = 0; + _gw[t_1] = F(sss * _w[t_1]); + + for (int t = t_2; t < t_T_1; t += C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float w = exp(_w[t-C]); + const float k = 
float(_k[t-2*C]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = k * v[j]; + s = (s + x) * w; + sum += s * gy[j]; + } + sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t])); + _gw[t] = F(sss * _w[t]); + } + _gw[t_T_1] = 0; +} + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); +} + +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_backward_111<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu); + kernel_backward_222<<>>(B, T, C, H, r, k, v, w, u, gy, gw); +} diff --git a/finetune/lora/v6/cuda/wkv6_op.cpp b/finetune/lora/v6/cuda/wkv6_op.cpp new file mode 100644 index 0000000..432ac56 --- /dev/null +++ b/finetune/lora/v6/cuda/wkv6_op.cpp @@ -0,0 +1,22 @@ +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); + +void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { + cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), 
gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv6 forward"); + m.def("backward", &backward, "wkv6 backward"); +} + +TORCH_LIBRARY(wkv6, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/finetune/lora/v6/src/__init__.py b/finetune/lora/v6/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/finetune/lora/v6/src/binidx.py b/finetune/lora/v6/src/binidx.py new file mode 100644 index 0000000..c2d60a1 --- /dev/null +++ b/finetune/lora/v6/src/binidx.py @@ -0,0 +1,303 @@ +from lib2to3.pgen2 import token +import os +import torch +import numpy as np +import shutil +import struct +from functools import lru_cache +from itertools import accumulate + + +def print_rank_0(*message): + pass + # """If distributed is initialized print only on rank 0.""" + # if torch.distributed.is_initialized(): + # if torch.distributed.get_rank() == 0: + # print(*message, flush=True) + # else: + # print(*message, flush=True) + + +def _warmup_mmap_file(path): + pass + # with open(path, "rb") as stream: + # while stream.read(100 * 1024 * 1024): + # pass + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: float, + 7: np.double, + 8: np.uint16, +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + ".idx" + + +def data_file_path(prefix_path): + return prefix_path + ".bin" + + +class MMapIndexedDataset(torch.utils.data.Dataset): + class Index(object): + _HDR_MAGIC = b"MMIDIDX\x00\x00" + + @classmethod + def writer(cls, path, dtype): + class _Writer(object): + def __enter__(self): + self._file = open(path, "wb") + + # Write Magic string so we can check the file format then opening it again. 
+ self._file.write(cls._HDR_MAGIC) + # Write version number + # Little endian unsigned 64 Bit integer + self._file.write(struct.pack(" 0: + # self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document') + self.data_pile = MMapIndexedDataset( + "/fsx/pile_deduped/pile_0.87_deduped_text_document" + ) + self.data_pile_size = ( + len(self.data_pile._bin_buffer) // self.data._index._dtype_size + ) + else: + self.data_pile = None + self.data_pile_size = 0 + + if args.my_pile_stage > 0: + # assert self.data_size == 332115325534 and self.vocab_size == 50277 + self.samples_per_epoch = args.epoch_steps * args.real_bsz + assert self.samples_per_epoch == 40320 + rank_zero_info( + f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########" + ) + dataset_slot = self.data_size // args.ctx_len + if args.my_pile_stage != 4: + assert MaybeIsPrime(args.magic_prime) + assert args.magic_prime % 3 == 2 + assert ( + args.magic_prime / dataset_slot > 0.99 + and args.magic_prime / dataset_slot <= 1 + ) + elif args.data_type == "numpy": + self.data = np.load(args.data_file).astype("int") + self.vocab_size = args.vocab_size + rank_zero_info( + f"Current vocab size = {self.vocab_size} (make sure it's correct)" + ) + self.data_size = len(self.data) + rank_zero_info(f"Data has {self.data_size} tokens.") + elif args.data_type == "uint16": + self.data = ( + np.fromfile(args.data_file, dtype=np.uint16) + .astype("int32") + .reshape(-1, args.my_sample_len) + ) + self.vocab_size = args.vocab_size + rank_zero_info( + f"Current vocab size = {self.vocab_size} (make sure it's correct)" + ) + self.data_size = self.data.shape[0] + rank_zero_info(f"Data has {self.data_size} samples.") + else: + if args.data_type == "dummy": + rank_zero_info("Building dummy data...") + self.data = "" + for i in range(100000): + aa = (i) % 10000 + bb = (i * i) % 10000 + cc = aa + bb + self.data += f".{aa}+{bb}={cc}." 
+ else: + self.data = open(args.data_file, "r", encoding=args.data_type).read() + rank_zero_info("Building token list...") + unique = sorted(list(set(self.data))) + self.vocab_size = len(unique) + # rank_zero_info() + # for u in unique: + # print(u, end=' ') + # rank_zero_info('\n\n') + xx = 0 + xxObj = {} + for u in unique: + xxObj[xx] = u + xx += 1 + with open( + f"{args.proj_dir}/vocab.json", "w", encoding="utf-8" + ) as vocab_file: + vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) + self.data_size = len(self.data) + rank_zero_info( + f"Data has {self.data_size} tokens, {self.vocab_size} vocab size." + ) + self.stoi = {ch: i for i, ch in enumerate(unique)} + self.itos = {i: ch for i, ch in enumerate(unique)} + + def __len__(self): + return self.args.epoch_steps * self.args.micro_bsz + + def __getitem__(self, idx): + args = self.args + rank = self.global_rank + epoch = self.real_epoch + world_size = self.world_size + # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}") + + if args.data_type == "uint16": + i = np.random.randint(0, self.data_size - 1) + dix = self.data[i] + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) + else: + ctx_len = args.ctx_len + req_len = ctx_len + 1 + magic_prime = args.magic_prime + data = self.data + + if args.my_pile_stage > 0: + ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank + + if args.my_qa_mask > 0: + ii_orig = ii + if ii % 2 == 0: + ii = -1 + data = self.data_pile + else: + ii = ii // 2 + if data == self.data_pile: + i = np.random.randint(0, self.data_pile_size - req_len) + else: + if args.my_pile_stage == 4 or ii < args.my_random_steps: + # cheat: pick a random spot in dataset + if args.my_pile_version == 1: + i = np.random.randint(0, self.data_size - req_len) + else: + i = np.random.randint(0, self.data_size) + else: + ii = ii - args.my_random_steps + factor = (math.sqrt(5) - 1) / 2 + factor = int(magic_prime * factor) + i = ((factor * ii * ii * ii) 
% magic_prime) * ctx_len + i = i + args.my_pile_shift + # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}") + else: + # cheat: pick a random spot in dataset + i = np.random.randint(0, self.data_size - req_len) + + if args.data_type == "binidx": + if args.my_pile_version == 1: + dix = data.get(idx=0, offset=i, length=req_len).astype(int) + # dix = data.pad(idx=idx, length=req_len).astype(int) + else: + # self.data : cutoff, chunk_count, data + for j in range(len(data)): + if i < data[j][0]: + ii = i + i = (i - (data[j - 1][0] if j > 0 else 0)) % data[j][1] + dix = ( + data[j][2] + .get(idx=0, offset=i, length=req_len) + .astype(int) + ) + # print(ii, j, i) + break + elif args.data_type == "numpy": + dix = data[i : i + req_len] + else: + dix = [self.stoi[s] for s in data[i : i + req_len]] + + if args.my_qa_mask == 1: + if data == self.data_pile: + z = [1] * ctx_len + else: + z = [0] * ctx_len + z_sum = 0 + isGood = False + for i in range(3, ctx_len): + if ( + dix[i] == 27 + and dix[i - 1] == 34 + and dix[i - 2] == 187 + and dix[i - 3] == 187 + ): + isGood = True + if dix[i] == 0: + isGood = False + if isGood: + z[i] = 1 + z_sum += 1 + if z_sum == 0: + z = [1] * ctx_len + i = np.random.randint(0, self.data_pile_size - req_len) + dix = self.data_pile.get( + idx=0, offset=i, length=req_len + ).astype(int) + z = torch.tensor(z, dtype=torch.bfloat16) + + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) + + # if ii_orig < 50: + # # if rank == 1: + # print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:]) + # else: + # exit(0) + + if args.my_qa_mask == 1: + return x, y, z + + return x, y diff --git a/finetune/lora/v6/src/model.py b/finetune/lora/v6/src/model.py new file mode 100644 index 0000000..a95b16a --- /dev/null +++ b/finetune/lora/v6/src/model.py @@ -0,0 +1,1086 @@ 
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import functools
from torch.utils.checkpoint import checkpoint as torch_checkpoint

import os
import math
import gc
import importlib

# `import importlib` alone does not guarantee that the `util` submodule is
# loaded; import it explicitly because `importlib.util.find_spec` is used below.
import importlib.util
import torch

# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy

# deepspeed is optional: only import it when it is installed.
if importlib.util.find_spec("deepspeed"):
    import deepspeed
    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam


# Module-wide LoRA settings. They are read when a layer is constructed, so they
# must be filled in before the model is built.
LORA_CONFIG = {
    "r": 0,
    "alpha": 0,
    "dropout": 0,
    "parts": {"att", "ln", "time", "ffn"},
}


class LoraLinear(nn.Module):
    """Bias-free linear layer with an additive low-rank (LoRA) term.

    The output is
    ``F.linear(x, weight) + scaling * F.linear(F.linear(dropout(x), lora_A), lora_B)``
    with ``scaling = alpha / r``. ``r``, ``alpha`` and ``dropout`` are taken
    from ``LORA_CONFIG`` at construction time. ``lora_B`` starts at zero, so
    the low-rank term contributes nothing until it has been trained.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: must be ``False``; a biased variant is not implemented.

    Raises:
        AssertionError: if ``bias`` is not ``False``.
        ValueError: if ``LORA_CONFIG["r"]`` is not positive.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool):
        super().__init__()

        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        assert bias == False, "Biased LoraLinear not supported"  # noqa: E712

        r, alpha, dropout = (
            LORA_CONFIG["r"],
            LORA_CONFIG["alpha"],
            LORA_CONFIG["dropout"],
        )
        if r <= 0:
            # Previously this surfaced as a bare ZeroDivisionError from alpha / r.
            raise ValueError(f"LoraLinear requires LORA_CONFIG['r'] > 0, got {r}")
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.empty(out_features, r))
        self.lora_dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        """Return the base projection of ``x`` plus the scaled low-rank update."""
        base = F.linear(x, self.weight)
        low_rank = F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B)
        return base + self.scaling * low_rank


def _make_linear(part, *args, **kwargs):
    """Build a LoraLinear if LoRA is enabled for ``part``, else a plain nn.Linear."""
    if part in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
        return LoraLinear(*args, **kwargs)
    return nn.Linear(*args, **kwargs)


@functools.wraps(LoraLinear)
def make_linear_att(*args, **kwargs):
    # Attention projections: LoRA applies when "att" is in LORA_CONFIG["parts"].
    return _make_linear("att", *args, **kwargs)


@functools.wraps(LoraLinear)
def make_linear_ffn(*args, **kwargs):
    # Feed-forward projections: LoRA applies when "ffn" is in LORA_CONFIG["parts"].
    return _make_linear("ffn", *args, **kwargs)


# Default RWKV_MY_TESTING to "" so the substring checks below never raise.
try:
    print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"])
except KeyError:
    os.environ["RWKV_MY_TESTING"] = ""


def __nop(ob):
    return ob


# With RWKV_JIT_ON=1 the model classes become TorchScript modules; otherwise
# they are ordinary nn.Modules and MyFunction is an identity decorator.
# An unset variable is treated as "off" instead of raising KeyError.
MyModule = nn.Module
MyFunction = __nop
if os.environ.get("RWKV_JIT_ON") == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method


########################################################################################################
# CUDA Kernel
########################################################################################################

from torch.utils.cpp_extension import load

# Passed to the kernels as the compile-time constant _N_.
HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"])


def _load_wkv_extension(name, extra_defines):
    """JIT-compile and load the extension ``name`` from finetune/lora/v6/cuda.

    The source paths are relative to the current working directory.
    """
    return load(
        name=name,
        sources=[
            f"finetune/lora/v6/cuda/{name}_op.cpp",
            f"finetune/lora/v6/cuda/{name}_cuda.cu",
        ],
        verbose=True,
        extra_cuda_cflags=[
            "-res-usage",
            "--use_fast_math",
            "-O3",
            "-Xptxas -O3",
            "--extra-device-vectorization",
            *extra_defines,
        ],
    )


def _empty_bf16(shape, device):
    """Allocate an uninitialised, contiguous bfloat16 buffer for a kernel to fill."""
    return torch.empty(
        shape,
        device=device,
        requires_grad=False,
        dtype=torch.bfloat16,
        memory_format=torch.contiguous_format,
    )


def _assert_bf16_contiguous(*tensors):
    """Check what the kernels rely on: bfloat16 dtype and contiguous layout."""
    for tensor in tensors:
        assert tensor.dtype == torch.bfloat16
        assert tensor.is_contiguous()


if "x060" in os.environ["RWKV_MY_TESTING"]:
    # The wkv6 kernel is additionally compiled with _T_ taken from RWKV_CTXLEN.
    wkv6_cuda = _load_wkv_extension(
        "wkv6",
        [
            f"-D_N_={HEAD_SIZE}",
            f"-D_T_={int(os.environ['RWKV_CTXLEN'])}",
        ],
    )

    class WKV_6(torch.autograd.Function):
        """Autograd wrapper around the wkv6 CUDA forward/backward kernels."""

        @staticmethod
        def forward(ctx, B, T, C, H, r, k, v, w, u):
            with torch.no_grad():
                _assert_bf16_contiguous(r, k, v, w, u)
                assert HEAD_SIZE == C // H
                ctx.B = B
                ctx.T = T
                ctx.C = C
                ctx.H = H
                # The kernel receives -exp(w) in float32 and exponentiates it itself.
                ew = (-torch.exp(w.float())).contiguous()
                ctx.save_for_backward(r, k, v, ew, u)
                y = _empty_bf16((B, T, C), r.device)
                wkv6_cuda.forward(B, T, C, H, r, k, v, ew, u, y)
                return y

        @staticmethod
        def backward(ctx, gy):
            with torch.no_grad():
                _assert_bf16_contiguous(gy)
                B = ctx.B
                T = ctx.T
                C = ctx.C
                H = ctx.H
                r, k, v, ew, u = ctx.saved_tensors
                gr = _empty_bf16((B, T, C), gy.device)
                gk = _empty_bf16((B, T, C), gy.device)
                gv = _empty_bf16((B, T, C), gy.device)
                gw = _empty_bf16((B, T, C), gy.device)
                gu = _empty_bf16((B, C), gy.device)
                wkv6_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu)
                # The kernel writes one gu row per batch entry; sum them to match u.
                gu = torch.sum(gu, 0).view(H, C // H)
                # B, T, C, H are plain ints and get no gradient.
                return (None, None, None, None, gr, gk, gv, gw, gu)

    def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
        return WKV_6.apply(B, T, C, H, r, k, v, w, u)

else:
    wkv5_cuda = _load_wkv_extension("wkv5", [f"-D_N_={HEAD_SIZE}"])

    class WKV_5(torch.autograd.Function):
        """Autograd wrapper around the wkv5 CUDA forward/backward kernels."""

        @staticmethod
        def forward(ctx, B, T, C, H, r, k, v, w, u):
            with torch.no_grad():
                _assert_bf16_contiguous(r, k, v, w, u)
                assert HEAD_SIZE == C // H
                ctx.B = B
                ctx.T = T
                ctx.C = C
                ctx.H = H
                # ew = -exp(w); eew = exp(ew). The forward kernel uses eew, and
                # the backward kernel needs both, so both are saved.
                ew = (-torch.exp(w.float())).contiguous()
                eew = (torch.exp(ew)).contiguous()
                ctx.save_for_backward(r, k, v, eew, ew, u)
                y = _empty_bf16((B, T, C), r.device)
                wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
                return y

        @staticmethod
        def backward(ctx, gy):
            with torch.no_grad():
                _assert_bf16_contiguous(gy)
                B = ctx.B
                T = ctx.T
                C = ctx.C
                H = ctx.H
                r, k, v, eew, ew, u = ctx.saved_tensors
                gr = _empty_bf16((B, T, C), gy.device)
                gk = _empty_bf16((B, T, C), gy.device)
                gv = _empty_bf16((B, T, C), gy.device)
                gw = _empty_bf16((B, C), gy.device)
                gu = _empty_bf16((B, C), gy.device)
                wkv5_cuda.backward(
                    B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu
                )
                # The kernel writes one gw/gu row per batch entry; sum them to
                # match the (H, C // H) parameters.
                gw = torch.sum(gw, 0).view(H, C // H)
                gu = torch.sum(gu, 0).view(H, C // H)
                # B, T, C, H are plain ints and get no gradient.
                return (None, None, None, None, gr, gk, gv, gw, gu)

    def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
        return WKV_5.apply(B, T, C, H, r, k, v, w, u)


########################################################################################################
class RWKV_TimeMix_RWKV5(MyModule):
    """RWKV-5 time-mixing layer built on the ``wkv5`` CUDA kernel.

    ``jit_func`` blends every position with its time-shifted copy and
    projects the blends to receptance, key, value and a SiLU gate.
    ``forward`` passes them to ``RUN_CUDA_RWKV5`` together with the learned
    ``time_decay`` / ``time_faaaa`` tensors, and ``jit_func_2`` applies
    GroupNorm (one group per head), the gate and the output projection.

    Args:
        args: run configuration; reads ``head_size_a``, ``dim_att``,
            ``head_size_divisor``, ``n_layer`` and ``n_embd``.
        layer_id: zero-based index of the layer, used to shape the
            initial mixing ratios and decay speeds.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.head_size = args.head_size_a
        assert HEAD_SIZE == self.head_size  # change HEAD_SIZE to match args.head_size_a
        self.n_head = args.dim_att // self.head_size
        # The per-head reshapes below need dim_att == n_head * head_size, so
        # the divisor to check is the head size. (dim_att % n_head passes for
        # e.g. dim_att=100, head_size=64 and only blows up at the reshape.)
        assert args.dim_att % self.head_size == 0
        self.head_size_divisor = args.head_size_divisor

        with torch.no_grad():
            # A one-layer model has no depth ramp; use 0 instead of 0 / 0.
            ratio_0_to_1 = layer_id / (args.n_layer - 1) if args.n_layer > 1 else 0.0
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd

            # fancy time_mix
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(
                torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
            )
            self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
            self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))

            # Same guard for the channel ramp: avoid dividing by dim_att - 1 == 0.
            att_span = max(args.dim_att - 1, 1)

            # fancy time_decay
            decay_speed = torch.ones(args.dim_att)
            for n in range(args.dim_att):
                decay_speed[n] = -6 + 5 * (n / att_span) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(
                decay_speed.reshape(self.n_head, self.head_size)
            )

            tmp = torch.zeros(args.dim_att)
            for n in range(args.dim_att):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / att_span)) + zigzag

            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))

        # Pads one row at the front of dim -2 and crops one at the end.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)

        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
        self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)

    @MyFunction
    def jit_func(self, x):
        """Return ``(r, k, v, g)`` projected from x blended with its shifted copy."""
        B, T, C = x.size()

        xx = self.time_shift(
            x
        )  # Mix x with the previous timestep to produce xk, xv, xr
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)

        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))

        return r, k, v, g

    @MyFunction
    def jit_func_2(self, x, g):
        """Normalise the kernel output per head, gate it and project it."""
        B, T, C = x.size()
        x = x.view(B * T, C)

        x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
        x = self.output(x * g)
        return x

    def forward(self, x):
        """Apply time mixing to ``x`` of shape ``(B, T, C)``."""
        B, T, C = x.size()
        H = self.n_head

        r, k, v, g = self.jit_func(x)

        x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)

        return self.jit_func_2(x, g)
class RWKV_Tmix_x060(MyModule):
    """RWKV-6 ("x060") time-mixing layer with data-dependent mixing and decay.

    On top of the static ``time_maa_*`` ratios, ``jit_func`` derives five
    per-token offsets (for w, k, v, r, g) from a small tanh bottleneck
    (``time_maa_w1`` / ``time_maa_w2``). A second bottleneck
    (``time_decay_w1`` / ``time_decay_w2``) adds a per-token term to
    ``time_decay``. ``forward`` passes the results to ``RUN_CUDA_RWKV6``.

    Args:
        args: run configuration; reads ``head_size_a``, ``dim_att``,
            ``head_size_divisor``, ``n_layer`` and ``n_embd``.
        layer_id: zero-based index of the layer, used to shape the
            initial mixing ratios and decay speeds.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.head_size = args.head_size_a
        self.n_head = args.dim_att // self.head_size
        # The per-head reshape of time_faaaa needs dim_att == n_head * head_size,
        # so the divisor to check is the head size, not the head count.
        assert args.dim_att % self.head_size == 0

        with torch.no_grad():
            # A one-layer model has no depth ramp; use 0 instead of 0 / 0.
            ratio_0_to_1 = layer_id / (args.n_layer - 1) if args.n_layer > 1 else 0.0
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd

            # fancy time_mix
            self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_v = nn.Parameter(
                1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            )
            self.time_maa_r = nn.Parameter(
                1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
            )
            self.time_maa_g = nn.Parameter(
                1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
            )

            TIME_MIX_EXTRA_DIM = 32  # generate TIME_MIX for w,k,v,r,g
            self.time_maa_w1 = nn.Parameter(
                torch.zeros(args.n_embd, TIME_MIX_EXTRA_DIM * 5).uniform_(-1e-4, 1e-4)
            )
            self.time_maa_w2 = nn.Parameter(
                torch.zeros(5, TIME_MIX_EXTRA_DIM, args.n_embd).uniform_(-1e-4, 1e-4)
            )

            # Same guard for the channel ramp: avoid dividing by dim_att - 1 == 0.
            att_span = max(args.dim_att - 1, 1)

            # fancy time_decay
            decay_speed = torch.ones(args.dim_att)
            for n in range(args.dim_att):
                decay_speed[n] = -6 + 5 * (n / att_span) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))

            TIME_DECAY_EXTRA_DIM = 64
            self.time_decay_w1 = nn.Parameter(
                torch.zeros(args.n_embd, TIME_DECAY_EXTRA_DIM).uniform_(-1e-4, 1e-4)
            )
            self.time_decay_w2 = nn.Parameter(
                torch.zeros(TIME_DECAY_EXTRA_DIM, args.dim_att).uniform_(-1e-4, 1e-4)
            )

            tmp = torch.zeros(args.dim_att)
            for n in range(args.dim_att):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / att_span)) + zigzag

            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))

        # Pads one row at the front of dim -2 and crops one at the end.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
        self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.ln_x = nn.GroupNorm(
            self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
        )

    @MyFunction
    def jit_func(self, x):
        """Return ``(r, k, v, g, w)`` for the wkv6 kernel."""
        B, T, C = x.size()

        # Difference to the shifted copy; every blend below is x + xx * ratio.
        xx = self.time_shift(x) - x

        # Bottleneck producing the five per-token ratio offsets.
        xxx = x + xx * self.time_maa_x
        xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        xw = x + xx * (self.time_maa_w + mw)
        xk = x + xx * (self.time_maa_k + mk)
        xv = x + xx * (self.time_maa_v + mv)
        xr = x + xx * (self.time_maa_r + mr)
        xg = x + xx * (self.time_maa_g + mg)

        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))

        # Per-token decay: static time_decay plus a low-rank correction.
        ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
        w = self.time_decay + ww

        return r, k, v, g, w

    @MyFunction
    def jit_func_2(self, x, g):
        """Normalise the kernel output per head, gate it and project it."""
        B, T, C = x.size()
        x = x.view(B * T, C)

        x = self.ln_x(x).view(B, T, C)
        x = self.output(x * g)
        return x

    def forward(self, x):
        """Apply time mixing to ``x`` of shape ``(B, T, C)``."""
        B, T, C = x.size()
        H = self.n_head

        r, k, v, g, w = self.jit_func(x)
        x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)

        return self.jit_func_2(x, g)
class RWKV_ChannelMix(MyModule):
    """Channel-mixing feed-forward block with a learned token-shift blend.

    The key and receptance inputs are per-channel interpolations between
    ``x`` and its time-shifted copy; the result is
    ``sigmoid(receptance(xr)) * value(relu(key(xk)) ** 2)``.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        # Pads one row at the front of dim -2 and crops one at the end.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            # Exponent goes from 1 (first layer) towards 0 (last layer).
            depth_exponent = 1.0 - (layer_id / args.n_layer)
            width = args.n_embd
            # Channel ramp 0, 1/width, 2/width, ... shaped (1, 1, width).
            ramp = torch.tensor([c / width for c in range(width)]).reshape(1, 1, width)
            self.time_mix_k = nn.Parameter(torch.pow(ramp, depth_exponent))
            self.time_mix_r = nn.Parameter(torch.pow(ramp, depth_exponent))

        self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
        self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        shifted = self.time_shift(x)
        key_in = x * self.time_mix_k + shifted * (1 - self.time_mix_k)
        rec_in = x * self.time_mix_r + shifted * (1 - self.time_mix_r)

        # Squared-ReLU hidden activation, projected back before gating.
        hidden = torch.relu(self.key(key_in)) ** 2
        projected = self.value(hidden)
        return torch.sigmoid(self.receptance(rec_in)) * projected
class RWKV_CMix_x060(MyModule):
    """RWKV-6 channel-mixing block.

    Same computation as the v5 channel mix, but the blend is expressed as
    ``x + (shift(x) - x) * time_maa_*``.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        # Pads one row at the front of dim -2 and crops one at the end.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            # Exponent goes from 1 (first layer) towards 0 (last layer).
            depth_exponent = 1.0 - (layer_id / args.n_layer)
            width = args.n_embd
            ramp = torch.tensor([c / width for c in range(width)]).reshape(1, 1, width)
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ramp, depth_exponent))
            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ramp, depth_exponent))

        self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
        self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        delta = self.time_shift(x) - x
        key_in = x + delta * self.time_maa_k
        rec_in = x + delta * self.time_maa_r

        hidden = torch.relu(self.key(key_in)) ** 2
        projected = self.value(hidden)
        return torch.sigmoid(self.receptance(rec_in)) * projected


########################################################################################################


class MishGLU(MyModule):
    """Gated feed-forward block: ``value(aa(xa) * mish(bb(xb)))``.

    ``xa`` and ``xb`` are per-channel blends of ``x`` with its time-shifted
    copy, weighted by ``time_mix_k`` and ``time_mix_r``.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            depth_exponent = 1.0 - (layer_id / args.n_layer)
            width = args.n_embd
            ramp = torch.tensor([c / width for c in range(width)]).reshape(1, 1, width)

            self.time_mix_k = nn.Parameter(torch.pow(ramp, depth_exponent))
            self.time_mix_r = nn.Parameter(torch.pow(ramp, depth_exponent))
            self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        shifted = self.time_shift(x)
        left_in = x * self.time_mix_k + shifted * (1 - self.time_mix_k)
        right_in = x * self.time_mix_r + shifted * (1 - self.time_mix_r)
        left = self.aa(left_in)
        right = self.bb(right_in)
        return self.value(left * F.mish(right))
class Block(nn.Module):
    """One residual RWKV layer: a time-mix (or pre-FFN) branch, then a channel-mix branch.

    The mixer classes are chosen from the ``RWKV_MY_TESTING`` environment
    variable: ``"x060"`` selects the v6 modules and ``"g"`` selects
    ``MishGLU`` for the feed-forward branch. Layer 0 additionally owns the
    input LayerNorm ``ln0`` and the optional positional embedding, and one
    configured layer can carry a small causal "tiny attention" over the
    raw embeddings.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        is_first = self.layer_id == 0

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if is_first:
            self.ln0 = nn.LayerNorm(args.n_embd)
            if args.my_pos_emb > 0:
                self.pos_emb_x = nn.Parameter(
                    torch.zeros((1, args.my_pos_emb, args.n_embd))
                )
                self.pos_emb_y = nn.Parameter(
                    torch.zeros((args.my_pos_emb, 1, args.n_embd))
                )

        # First branch: layer 0 may swap attention for a channel mix.
        if is_first and self.args.pre_ffn > 0:
            self.ffnPre = RWKV_ChannelMix(args, 0)
        elif "x060" in os.environ["RWKV_MY_TESTING"]:
            self.att = RWKV_Tmix_x060(args, layer_id)
        else:
            self.att = RWKV_TimeMix_RWKV5(args, layer_id)

        # Second branch: feed-forward flavour.
        if "g" in os.environ["RWKV_MY_TESTING"]:
            self.ffn = MishGLU(args, layer_id)
        elif "x060" in os.environ["RWKV_MY_TESTING"]:
            self.ffn = RWKV_CMix_x060(args, layer_id)
        else:
            self.ffn = RWKV_ChannelMix(args, layer_id)

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            self.tiny_ln = nn.LayerNorm(args.n_embd)
            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
            # Lower-triangular mask so a position only attends to the past.
            causal = torch.tril(torch.ones(args.ctx_len, args.ctx_len))
            self.register_buffer("tiny_mask", causal)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

    def forward(self, x, x_emb=None):
        """Apply the layer to ``x`` (B, T, C); ``x_emb`` feeds tiny attention."""
        args = self.args
        B, T, C = x.size()
        is_first = self.layer_id == 0

        if is_first:
            x = self.ln0(x)
            if args.my_pos_emb > 0:
                grid = self.pos_emb_x + self.pos_emb_y
                x = x + grid.reshape(T + 1, -1)[:-1, :]

        # Resolve the first-branch module once; only the chosen attribute exists.
        first_branch = self.ffnPre if (is_first and args.pre_ffn > 0) else self.att

        if self.args.dropout == 0:
            x = x + first_branch(self.ln1(x))
            x = x + self.ffn(self.ln2(x))
        else:
            x = self.drop0(x + first_branch(self.ln1(x)))
            x = self.drop1(x + self.ffn(self.ln2(x)))

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            normed = self.tiny_ln(x)
            queries = self.tiny_q(normed)[:, :T, :]
            keys = self.tiny_k(normed)[:, :T, :]
            scores = (queries @ keys.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
            scores = scores.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
            x = x + scores @ self.tiny_v(x_emb)
        return x
class L2Wrap(torch.autograd.Function):
    """Identity on the loss that injects an extra gradient into the logits.

    The forward pass returns ``loss`` unchanged. The backward pass passes
    the incoming gradient through to ``loss`` and gives ``y`` a gradient
    that is zero everywhere except at each position's arg-max entry, where
    it equals ``max_value * 1e-4 / (y.shape[0] * y.shape[1])``.
    """

    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (logits,) = ctx.saved_tensors
        # to encourage the logits to be close to 0
        scale = 1e-4 / (logits.shape[0] * logits.shape[1])
        peak, peak_index = logits.max(dim=-1, keepdim=True)
        logit_grad = torch.zeros_like(logits).scatter_(-1, peak_index, peak * scale)
        return grad_output, logit_grad
def configure_optimizers(self):
    """Sort parameters into learning-rate groups and build the Adam optimizer.

    With ``args.layerwise_lr > 0`` the ``time_*`` parameters are placed in
    1x/2x/3x groups (the scales become 1/5/5 when ``my_pile_stage == 2``).
    Remaining tensors with two or more non-singleton dims go to a
    weight-decay group when ``args.weight_decay > 0``, and everything else
    trains at 1x. Returns ``DeepSpeedCPUAdam`` when DeepSpeed offloading is
    active and ``FusedAdam`` otherwise.
    """
    args = self.args
    layerwise = args.layerwise_lr > 0
    stage_two = args.my_pile_stage == 2
    named = dict(self.named_parameters())

    buckets = {"decay": set(), "1x": set(), "2x": set(), "3x": set()}
    for name, param in named.items():
        if layerwise and ("_w1" in name or "_w2" in name):
            target = "1x"
        elif layerwise and ("time_mix" in name or "time_maa" in name):
            target = "2x" if stage_two else "1x"
        elif layerwise and ("time_decay" in name or "time_daaaa" in name):
            target = "3x" if stage_two else "2x"
        elif layerwise and "time_faaaa" in name:
            target = "2x" if stage_two else "1x"
        elif layerwise and "time_first" in name:
            target = "3x"
        elif args.weight_decay > 0 and len(param.squeeze().shape) >= 2:
            target = "decay"
        else:
            target = "1x"
        buckets[target].add(name)

    def make_group(bucket, lr_scale, weight_decay=0.0):
        # Sorted names keep the parameter order inside a group deterministic.
        return {
            "params": [named[n] for n in sorted(buckets[bucket])],
            "weight_decay": weight_decay,
            "my_lr_scale": lr_scale,
        }

    if layerwise:
        scales = (1.0, 5.0, 5.0) if stage_two else (1.0, 2.0, 3.0)
        optim_groups = [
            make_group(bucket, scale)
            for bucket, scale in zip(("1x", "2x", "3x"), scales)
        ]
    else:
        optim_groups = [make_group("1x", 1.0)]

    decaying = args.weight_decay > 0
    if decaying:
        optim_groups.append(make_group("decay", 1.0, args.weight_decay))

    shared = {
        "lr": self.args.lr_init,
        "betas": self.args.betas,
        "eps": self.args.adam_eps,
        "bias_correction": True,
        "amsgrad": False,
    }
    if not decaying:
        # Plain Adam: decay is switched off at the optimizer level as well.
        shared["weight_decay"] = 0

    if self.deepspeed_offload:
        return DeepSpeedCPUAdam(optim_groups, adamw_mode=decaying, **shared)
    return FusedAdam(optim_groups, adam_w_mode=decaying, **shared)
    # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
def forward(self, idx):
    """Run the model on token ids and return the output of ``self.head``.

    Args:
        idx: integer tensor of token ids, shape ``(B, T)`` with
            ``T <= args.ctx_len``.

    Returns:
        ``self.head(x)`` for the final hidden states, plus the head-QK copy
        term when ``args.head_qk > 0``.

    Raises:
        AssertionError: if ``T`` exceeds ``args.ctx_len``.
    """
    args = self.args
    B, T = idx.size()
    assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

    x = self.emb(idx)
    x_emb = x

    if args.dropout > 0:
        x = self.drop0(x)

    # Raw embeddings are passed to the blocks when tiny attention is enabled.
    # The LoRA checkpoint path always passes them.
    with_emb = args.tiny_att_dim > 0
    for block in self.blocks:
        if args.grad_cp == 1:
            if args.lora:
                x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
            elif with_emb:
                x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
            else:
                x = deepspeed.checkpointing.checkpoint(block, x)
        elif with_emb:
            x = block(x, x_emb)
        else:
            x = block(x)

    x = self.ln_out(x)

    if args.head_qk > 0:
        q = self.head_q(x)[:, :T, :]
        k = self.head_k(x)[:, :T, :]
        c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
        c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

        # F.one_hot returns int64 and matmul requires matching dtypes, so cast
        # to whatever precision ``c`` has. The fp32/tf32 branch used the raw
        # int64 tensor and failed; an unknown float mode skipped the product
        # and left ``c`` with shape (B, T, T).
        c = c @ F.one_hot(idx, num_classes=args.vocab_size).to(c.dtype)

        x = self.head(x) + c
    else:
        x = self.head(x)

    return x
def generate_init_weight(self):
    """Build a freshly initialised copy of the model's state dict.

    Entries whose name marks them as norm, ``time_*``, mask or positional
    tensors are taken from the current state (``ln_x.weight`` is reset to
    ``((layer + 1) / n_layer) ** 0.7``). Every other entry is treated as a
    2-D matrix and re-drawn: zeros, uniform in ``[-lr_init, lr_init]`` for
    the embedding, or orthogonal with a gain derived from its shape.

    Returns:
        dict mapping parameter name to a CPU tensor, cast to half or
        bfloat16 when ``RWKV_FLOAT_MODE`` is ``fp16`` / ``bf16``.
    """
    print(
        """
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
    )
    # Name fragments whose matrices start at zero.
    zero = [
        ".att.output.",
        ".ffn.value.",
        ".ffn.receptance.",
        ".ffnPre.value.",
        ".ffnPre.receptance.",
        "head_q.",
        ".oo.",
        ".rr.",
    ]
    keep_markers = ("ln_", ".ln", "time_", "_mask", "pos_emb", ".mask.")

    m = {}
    # Build the state dict once; rebuilding it per key made this loop quadratic.
    state = self.state_dict()
    for n, p in state.items():
        shape = p.shape

        gain = 1.0
        scale = 1.0
        if any(marker in n for marker in keep_markers):
            if "ln_x.weight" in n:
                # The layer index is the second dotted component of the name.
                layer_scale = (1 + int(n.split(".")[1])) / self.args.n_layer
                m[n] = (p * 0.0) + (layer_scale**0.7)
            else:
                m[n] = p
        else:
            if n == "emb.weight":
                # A negative scale selects the uniform init below.
                scale = -1 * self.args.lr_init
            else:
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])

                if any(kk in n for kk in zero):
                    scale = 0
                if n == "head.weight":
                    scale = 0.5
                if "head_k." in n:
                    scale = 0.1
                if "head_q." in n:
                    scale = 0

            print(
                f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}"
            )

            if self.args.accelerator.upper() == "GPU":
                m[n] = torch.empty((shape[0], shape[1]), device="cuda")
            else:
                m[n] = torch.empty((shape[0], shape[1]))

            if scale == 0:
                nn.init.zeros_(m[n])
            elif scale < 0:
                nn.init.uniform_(m[n], a=scale, b=-scale)
            else:
                nn.init.orthogonal_(m[n], gain=gain * scale)

        m[n] = m[n].cpu()
        if os.environ["RWKV_FLOAT_MODE"] == "fp16":
            m[n] = m[n].half()
        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
            m[n] = m[n].bfloat16()

        # if n == "emb.weight":
        #     print(m[n])

    gc.collect()
    torch.cuda.empty_cache()
    return m
def my_save(args, trainer, dd, ff):
    """Persist a checkpoint, routing special run names through S3.

    Args:
        args: run configuration; only ``args.strategy`` is read.
        trainer: object providing ``save_checkpoint`` (used for DeepSpeed
            stage 3 only).
        dd: state dict to store.
        ff: destination path. Paths containing ``14b-run1``, ``world/14b``
            or ``world/7b`` are written to ``/dev/shm`` and then moved to an
            S3 bucket by a background ``aws s3 mv`` process.
    """
    if "14b-run1" in ff:
        fn = ff.split("/")[-1]
        fff = "/dev/shm/" + fn
        torch.save(dd, fff)
        # Argument list without a shell: the file name is passed verbatim
        # instead of being word-split and interpreted by /bin/sh.
        subprocess.Popen(
            ["aws", "s3", "mv", fff, f"s3://rwkv-14b-4k/{fn}", "--quiet"]
        )
    elif ("world/14b" in ff) or ("world/7b" in ff):
        aa = ff.split("/")[1]
        fn = ff.split("/")[-1]
        fff = f"/dev/shm/{aa}-{fn}"
        torch.save(dd, fff)
        subprocess.Popen(
            ["aws", "s3", "mv", fff, f"s3://rwkv-world/{aa}-{fn}", "--quiet"]
        )
    else:
        if "deepspeed_stage_3" in args.strategy:
            trainer.save_checkpoint(ff, weights_only=True)
        else:
            torch.save(dd, ff)
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
    """Set learning rate and weight decay for the coming step; open logs on step 0.

    The learning rate follows a linear or exponential decay between
    ``lr_init`` and ``lr_final``, is optionally replaced or averaged with a
    token-based cosine schedule (``my_exit_tokens``), and is scaled down
    during the first ``warmup_steps`` steps. When the cosine schedule
    completes, the final checkpoint is saved and the process exits.
    """
    args = self.args
    # if args.cuda_cleanup > 0:
    #     torch.cuda.empty_cache()
    real_step = trainer.global_step + args.epoch_begin * args.epoch_steps

    # LR schedule
    w_step = args.warmup_steps
    # Defined up front so the weight-decay schedule below also works on the
    # constant-LR path, where it used to hit an unbound ``progress``.
    # With no schedule progress the weight decay stays at its initial value.
    progress = 0
    if args.lr_final == args.lr_init or args.epoch_count == 0:
        lr = args.lr_init
    else:
        decay_step = real_step - args.my_pile_edecay * args.epoch_steps
        decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
        progress = (decay_step - w_step + 1) / (decay_total - w_step)
        progress = min(1, max(0, progress))

        if args.lr_final == 0 or args.lr_init == 0:  # linear decay
            lr = args.lr_init + (args.lr_final - args.lr_init) * progress
        else:  # exp decay
            lr = args.lr_init * math.exp(
                math.log(args.lr_final / args.lr_init) * pow(progress, 1)
            )
        # if trainer.is_global_zero:
        #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)

    if args.my_exit_tokens != 0:  # cosine decay
        real_tokens = real_step * args.ctx_len * args.real_bsz
        warmup_tokens = w_step * args.ctx_len * args.real_bsz
        progress = (real_tokens - warmup_tokens) / (
            abs(args.my_exit_tokens) - warmup_tokens
        )
        progress = max(0, min(1, progress))
        lr_final_factor = args.lr_final / args.lr_init
        lr_mult = (0.5 + lr_final_factor / 2) + (
            0.5 - lr_final_factor / 2
        ) * math.cos(math.pi * progress)
        if args.my_exit_tokens > 0:
            lr = args.lr_init * lr_mult
        else:
            # Negative my_exit_tokens: average the cosine LR with the decay above.
            lr = (lr + args.lr_init * lr_mult) / 2
        if progress >= 1:
            if (trainer.is_global_zero) or ("deepspeed_stage_3" in args.strategy):
                my_save(
                    args,
                    trainer,
                    pl_module.state_dict(),
                    f"{args.proj_dir}/rwkv-final.pth",
                )
                # Same effect as exit(0) without relying on the ``site`` builtin.
                raise SystemExit(0)
    if trainer.global_step < w_step:
        lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)

    if args.weight_decay_final > 0:
        wd_now = args.weight_decay * math.exp(
            math.log(args.weight_decay_final / args.weight_decay) * progress
        )
    else:
        wd_now = args.weight_decay

    for param_group in trainer.optimizers[0].param_groups:
        # Groups created with zero decay stay at zero.
        if param_group["weight_decay"] > 0:
            param_group["weight_decay"] = wd_now
        if args.layerwise_lr > 0:
            param_group["lr"] = lr * param_group["my_lr_scale"]
            # print(param_group["lr"], param_group["my_lr_scale"])
        else:
            param_group["lr"] = lr

    trainer.my_lr = lr
    trainer.my_wd = wd_now
    # rank_zero_info(f"{real_step} {lr}")

    if trainer.global_step == 0:
        if trainer.is_global_zero:  # logging
            trainer.my_loss_sum = 0
            trainer.my_loss_count = 0
            # Kept open on the trainer; on_train_epoch_end appends to it.
            trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
            trainer.my_log.write(
                f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n"
            )
            try:
                print(f"\n{trainer.strategy.config}\n")
                trainer.my_log.write(f"{trainer.strategy.config}\n")
            except Exception:
                # Best effort: not every strategy exposes ``config``. Unlike a
                # bare except, this lets KeyboardInterrupt/SystemExit through.
                pass
            trainer.my_log.flush()
            if len(args.wandb) > 0:
                print("Login to wandb...")
                import wandb

                wandb.init(
                    project=args.wandb,
                    name=args.run_name + " " + args.my_timestamp,
                    config=args,
                    save_code=False,
                )
                trainer.my_wandb = wandb
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
    """Log throughput and loss after a step and save the final checkpoint when due.

    Rank zero records the step time, tokens per second, running epoch loss
    and (optionally) wandb metrics. When ``args.magic_prime > 0`` and the
    step count reaches the value derived from it, the model state is saved
    as ``rwkv-final.pth``.
    """
    args = self.args
    token_per_step = args.ctx_len * args.real_bsz
    real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
    if trainer.is_global_zero:  # logging
        t_now = time.time_ns()
        kt_s = 0
        try:
            t_cost = (t_now - trainer.my_time_ns) / 1e9
            kt_s = token_per_step / t_cost / 1000
            self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
            self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
        except Exception:
            # Best effort: the first step has no previous timestamp and a zero
            # interval cannot be divided. Unlike a bare except, this lets
            # KeyboardInterrupt/SystemExit through.
            pass
        trainer.my_time_ns = t_now
        if pl.__version__[0] == "2":
            trainer.my_loss = outputs["loss"]
        else:
            trainer.my_loss = trainer.my_loss_all.float().mean().item()
        trainer.my_loss_sum += trainer.my_loss
        trainer.my_loss_count += 1
        trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
        self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
        self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
        # self.log("s", real_step, prog_bar=True, on_step=True)

        if len(args.wandb) > 0:
            lll = {
                "loss": trainer.my_loss,
                "lr": trainer.my_lr,
                "wd": trainer.my_wd,
                "Gtokens": real_step * token_per_step / 1e9,
            }
            if kt_s > 0:
                lll["kt/s"] = kt_s
            trainer.my_wandb.log(lll, step=int(real_step))
    if (trainer.is_global_zero) or (
        "deepspeed_stage_3" in args.strategy
    ):  # save pth
        if args.magic_prime > 0:
            expand_factor = 2 if args.my_qa_mask > 0 else 1
            final_step = (
                int(args.magic_prime * expand_factor // args.real_bsz)
                - 1
                + int(args.my_random_steps)
            )
            if int(real_step) == final_step:
                to_save_dict = pl_module.state_dict()
                my_save(
                    args,
                    trainer,
                    to_save_dict,
                    f"{args.proj_dir}/rwkv-final.pth",
                )
def on_train_epoch_end(self, trainer, pl_module):
    """Save the epoch checkpoint, append a line to the training log and maybe exit.

    A checkpoint is written every ``args.epoch_save`` epochs and on the
    last epoch. With ``args.lora`` set, only LoRA tensors are kept, plus
    ``time_*`` / ``ln`` tensors when those parts are enabled in
    ``LORA_CONFIG`` and any tensor whose name contains ``img``.
    """
    args = self.args
    to_save_dict = {}
    if (trainer.is_global_zero) or (
        "deepspeed_stage_3" in args.strategy
    ):  # save pth
        if (
            args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0
        ) or (trainer.current_epoch == args.epoch_count - 1):
            if args.data_type == "wds_img":
                raw_dict = pl_module.state_dict()
                to_save_dict = {
                    k: v
                    for k, v in raw_dict.items()
                    if k.startswith(("encoder.", "decoder."))
                }
            else:
                to_save_dict = pl_module.state_dict()

            # The former ``data_type == "img"`` loop only re-assigned every
            # entry to itself, so it has been dropped.

            if args.lora:
                enable_time_finetune = "time" in LORA_CONFIG["parts"]
                enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
                to_save_dict = {
                    name: state
                    for name, state in to_save_dict.items()
                    if "img" in name
                    or ".lora_" in name
                    or (enable_time_finetune and ".time_" in name)
                    or (enable_ln_finetune and ".ln" in name)
                }

            try:
                my_save(
                    args,
                    trainer,
                    to_save_dict,
                    f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
                )
            except Exception as e:
                # A failed save must not abort training.
                print("Error\n\n", e, "\n\n")

    if trainer.is_global_zero:  # logging
        try:
            perplexity = math.exp(trainer.my_epoch_loss)
        except OverflowError:
            # A diverged loss should still produce a log line.
            perplexity = math.inf
        trainer.my_log.write(
            f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {perplexity:.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n"
        )
        trainer.my_log.flush()

        trainer.my_loss_sum = 0
        trainer.my_loss_count = 0
        if (args.epoch_begin + trainer.current_epoch) >= args.my_exit:
            # Same effect as exit(0) without relying on the ``site`` builtin.
            raise SystemExit(0)
class TOKENIZER:
    """Tokenizer wrapper with nucleus (top-p) sampling.

    When ``WORD_NAME`` is a list, a HuggingFace tokenizer is loaded: two
    equal entries name a single tokenizer file, otherwise the pair is passed
    to ``GPT2TokenizerFast``. Any other ``WORD_NAME`` selects character mode
    and reads the vocabulary from ``WORD_NAME + ".json"`` (UTF-16).
    """

    def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"):
        if "list" in str(type(WORD_NAME)):
            self.charMode = False
            if WORD_NAME[0] == WORD_NAME[1]:
                from transformers import PreTrainedTokenizerFast

                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
            else:
                from transformers import GPT2TokenizerFast

                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
            self.vocab_size = len(self.tokenizer)
        else:
            self.charMode = True
            with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)

            self.vocab_size = len(self.word_table)

            self.stoi = {v: int(k) for k, v in self.word_table.items()}
            self.itos = {int(k): v for k, v in self.word_table.items()}

            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        """Strip each line, drop empty lines and return the text prefixed by a newline."""
        lines = context.strip().split("\n")
        lines = [ln.strip().strip("\u3000").strip("\r") for ln in lines]
        lines = [ln for ln in lines if ln != ""]
        # The result always starts with "\n"; empty input yields exactly "\n".
        return "\n" + ("\n".join(lines)).strip()

    def sample_logits(
        self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None
    ):
        """Sample one token id from the logits ``out``.

        Args:
            out: 1-D tensor of logits.
            x: token history; only the last entry is read.
            ctx_len: unused.
            temperature: surviving probabilities are raised to
                ``1 / temperature`` before sampling.
            top_p_usual: nucleus threshold; ``None`` disables truncation.
            top_p_newline: threshold used in character mode when the last
                token is a newline.

        Returns:
            The sampled index: a NumPy integer when ``RWKV_RUN_DEVICE`` is
            ``"cpu"``, otherwise a 0-dim tensor.
        """
        # out[self.UNKNOWN_CHAR] = -float('Inf')
        last_char = int(x[-1])

        probs = F.softmax(out, dim=-1)

        if self.charMode and self.itos[last_char] == "\n":
            top_p = top_p_newline
        else:
            top_p = top_p_usual

        if os.environ["RWKV_RUN_DEVICE"] == "cpu":
            probs = probs.detach().cpu().numpy()
            if top_p is not None:
                sorted_probs = np.sort(probs)[::-1]
                cumulative_probs = np.cumsum(sorted_probs)
                # Probability of the first token at which the cumulative mass
                # exceeds top_p; everything below it is dropped.
                cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
                probs[probs < cutoff] = 0
            if temperature != 1.0:
                # ndarray has no ``.pow`` method; use the NumPy function.
                probs = np.power(probs, 1.0 / temperature)
            probs = probs / np.sum(probs)
            return np.random.choice(a=len(probs), p=probs)

        if top_p is not None:
            sorted_probs = torch.sort(probs, descending=True)[0]
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)
        # multinomial accepts unnormalised weights.
        return torch.multinomial(probs, num_samples=1)[0]
sample_logits( + self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None + ): + # out[self.UNKNOWN_CHAR] = -float('Inf') + lastChar = int(x[-1]) + + probs = F.softmax(out, dim=-1) + + if self.charMode: + if self.itos[lastChar] == "\n": + top_p = top_p_newline + else: + top_p = top_p_usual + else: + top_p = top_p_usual + + if os.environ["RWKV_RUN_DEVICE"] == "cpu": + probs = probs.numpy() + sorted_probs = np.sort(probs)[::-1] + cumulative_probs = np.cumsum(sorted_probs) + cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) + probs[probs < cutoff] = 0 + if temperature != 1.0: + probs = probs.pow(1.0 / temperature) + probs = probs / np.sum(probs) + out = np.random.choice(a=len(probs), p=probs) + return out + else: + sorted_probs = torch.sort(probs, descending=True)[0] + cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy() + cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) + probs[probs < cutoff] = 0 + if temperature != 1.0: + probs = probs.pow(1.0 / temperature) + out = torch.multinomial(probs, num_samples=1)[0] + return out + + +def MaybeIsPrime(number): + if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number): + return True + else: + return False + + +def FermatPrimalityTest(number): + if number > 1: + for time in range(3): + randomNumber = random.randint(2, number) - 1 + if pow(randomNumber, number - 1, number) != 1: + return False + return True + else: + return False + + +def MillerRabinPrimalityTest(number): + if number == 2: + return True + elif number == 1 or number % 2 == 0: + return False + oddPartOfNumber = number - 1 + timesTwoDividNumber = 0 + while oddPartOfNumber % 2 == 0: + oddPartOfNumber = oddPartOfNumber // 2 + timesTwoDividNumber = timesTwoDividNumber + 1 + + for time in range(3): + while True: + randomNumber = random.randint(2, number) - 1 + if randomNumber != 0 and randomNumber != 1: + break + + randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number) + + if 
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
"""Command-line entry point that configures and launches RWKV (LoRA) training.

The script parses hyper-parameters, derives run settings, builds the dataset
and model, optionally restricts training to LoRA / layer-norm / time
parameters, loads a checkpoint and hands everything to a PyTorch Lightning
``Trainer``.
"""

import logging

logging.basicConfig(level=logging.INFO)

_SUPPORTED_DATA_TYPES = ("utf-8", "utf-16le", "numpy", "binidx", "dummy", "uint16")
_SUPPORTED_PRECISIONS = ("fp32", "tf32", "fp16", "bf16")


def _saved_epochs(proj_dir):
    """Return the sorted epoch numbers of the ``rwkv-*.pth`` files in *proj_dir*.

    ``rwkv-init.pth`` is reported as -1 and ``rwkv-final.pth`` is ignored.
    """
    import os

    epochs = []
    for fname in os.listdir(proj_dir):
        if fname.startswith("rwkv") and fname.endswith(".pth"):
            tag = fname.split("-")[1].split(".")[0]
            if tag != "final":
                epochs.append(-1 if tag == "init" else int(tag))
    epochs.sort()
    return epochs


def _checkpoint_path(proj_dir, epoch):
    """Return the checkpoint file name used for *epoch* (-1 means the init weights)."""
    if epoch == -1:
        return f"{proj_dir}/rwkv-init.pth"
    return f"{proj_dir}/rwkv-{epoch}.pth"


def _load_state_dict(path):
    """Load a checkpoint onto the CPU and drop ``_forward_module.`` from its keys.

    NOTE: ``torch.load`` unpickles the file, so only trusted checkpoints
    should be passed here.
    """
    import torch

    state = torch.load(path, map_location="cpu")
    for key in list(state.keys()):
        if key.startswith("_forward_module."):
            state[key.replace("_forward_module.", "")] = state.pop(key)
    return state


def _enable_lora_training(model, enable_ln_finetune, enable_time_finetune):
    """Freeze *model*, then re-enable gradients for the LoRA-trainable parts.

    Trainable afterwards: parameters whose name contains ``lora_`` (in modules
    that own such parameters), every parameter of modules whose name contains
    ``.ln`` when *enable_ln_finetune* is set, and parameters whose name starts
    with ``time`` when *enable_time_finetune* is set.
    """
    model.requires_grad_(False)
    for name, module in model.named_modules():
        if any(n.startswith("lora_") for n, _ in module.named_parameters()):
            print(f" LoRA additionally training module {name}")
            for pname, param in module.named_parameters():
                param.requires_grad = "lora_" in pname
        elif enable_ln_finetune and ".ln" in name:
            print(f" LoRA additionally training module {name}")
            for param in module.parameters():
                param.requires_grad = True
        elif enable_time_finetune and any(
            n.startswith("time") for n, _ in module.named_parameters()
        ):
            for pname, param in module.named_parameters():
                if pname.startswith("time"):
                    print(f" LoRA additionally training parameter {pname}")
                    param.requires_grad = True


if __name__ == "__main__":
    from argparse import ArgumentParser
    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
    import pytorch_lightning as pl

    rank_zero_info("########## work in progress ##########")

    # The Trainer is built differently for Lightning 2.x and older versions.
    is_pl_v2 = pl.__version__[0] == "2"

    parser = ArgumentParser()

    parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
    parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
    parser.add_argument("--proj_dir", default="out", type=str)
    parser.add_argument("--random_seed", default=-1, type=int)

    parser.add_argument("--data_file", default="", type=str)
    parser.add_argument("--data_type", default="utf-8", type=str)
    parser.add_argument("--vocab_size", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)

    parser.add_argument("--ctx_len", default=1024, type=int)
    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has [epoch_steps] steps
    parser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_final
    parser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = x
    parser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"

    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
    parser.add_argument("--n_layer", default=6, type=int)
    parser.add_argument("--n_embd", default=512, type=int)
    parser.add_argument("--dim_att", default=0, type=int)
    parser.add_argument("--dim_ffn", default=0, type=int)
    parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer

    parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
    parser.add_argument("--lr_final", default=1e-5, type=float)
    parser.add_argument("--warmup_steps", default=-1, type=int)  # try 50 if you load a model
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
    parser.add_argument("--adam_eps", default=1e-8, type=float)
    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
    parser.add_argument("--dropout", default=0, type=float)  # try 0.01 / 0.02 / 0.05 / 0.1
    parser.add_argument("--weight_decay", default=0, type=float)  # try 0.1 / 0.01 / 0.001
    parser.add_argument("--weight_decay_final", default=-1, type=float)

    parser.add_argument("--my_pile_version", default=1, type=int)  # my special pile version
    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
    parser.add_argument("--my_pile_edecay", default=0, type=int)
    parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
    parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough

    parser.add_argument("--my_sample_len", default=0, type=int)
    parser.add_argument("--my_ffn_shift", default=1, type=int)
    parser.add_argument("--my_att_shift", default=1, type=int)
    parser.add_argument("--head_size_a", default=64, type=int)  # can try larger values for larger models
    parser.add_argument("--head_size_divisor", default=8, type=int)
    parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--load_partial", default=0, type=int)
    parser.add_argument("--magic_prime", default=0, type=int)
    parser.add_argument("--my_qa_mask", default=0, type=int)
    parser.add_argument("--my_random_steps", default=0, type=int)
    parser.add_argument("--my_testing", default="", type=str)
    parser.add_argument("--my_exit", default=99999999, type=int)
    parser.add_argument("--my_exit_tokens", default=0, type=int)

    # LORA
    parser.add_argument("--emb", action="store_true")
    parser.add_argument("--lora", action="store_true")
    parser.add_argument("--lora_load", default="", type=str)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument("--lora_alpha", default=32, type=float)
    parser.add_argument("--lora_dropout", default=0.01, type=float)
    parser.add_argument("--lora_parts", default="att,ln,time", type=str)

    if is_pl_v2:
        parser.add_argument("--accelerator", default="gpu", type=str)
        parser.add_argument("--strategy", default="auto", type=str)
        parser.add_argument("--devices", default=1, type=int)
        parser.add_argument("--num_nodes", default=1, type=int)
        parser.add_argument("--precision", default="fp16", type=str)
        parser.add_argument("--accumulate_grad_batches", default=1, type=int)
    else:
        parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # Validate up front (explicitly, so the checks survive `python -O`) and
    # before any directory is created or banner is printed.
    if args.data_type not in _SUPPORTED_DATA_TYPES:
        parser.error(
            f"--data_type must be one of {list(_SUPPORTED_DATA_TYPES)}, got {args.data_type!r}"
        )
    if args.precision not in _SUPPORTED_PRECISIONS:
        parser.error(
            f"--precision must be one of {list(_SUPPORTED_PRECISIONS)}, got {args.precision!r}"
        )

    ########################################################################################################

    import os
    import warnings
    import math
    import datetime
    import sys
    import time
    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    # str() keeps the membership test valid when no strategy was given (None).
    uses_deepspeed = "deepspeed" in str(args.strategy)
    deepspeed_version = None
    if uses_deepspeed:
        import deepspeed

        deepspeed_version = getattr(deepspeed, "__version__", None)
    from pytorch_lightning import seed_everything

    if args.random_seed >= 0:
        print(
            f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n"
            * 3
        )
        seed_everything(args.random_seed)

    np.set_printoptions(precision=4, suppress=True, linewidth=200)
    warnings.filterwarnings(
        "ignore", ".*Consider increasing the value of the `num_workers` argument*"
    )
    warnings.filterwarnings(
        "ignore", ".*The progress bar already tracks a metric with the*"
    )
    # os.environ["WDS_SHOW_SEED"] = "1"

    # Fixed Trainer settings (not exposed on the command line) plus values
    # derived from the parsed arguments.
    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    args.enable_checkpointing = False
    args.replace_sampler_ddp = False
    args.logger = False
    args.gradient_clip_val = 1.0
    args.num_sanity_val_steps = 0
    args.check_val_every_n_epoch = int(1e20)
    args.log_every_n_steps = int(1e20)
    # Captured before the pile-stage logic below may overwrite epoch_count.
    args.max_epochs = args.epoch_count  # -1 continue forever
    args.betas = (args.beta1, args.beta2)
    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
    os.environ["RWKV_MY_TESTING"] = args.my_testing
    os.environ["RWKV_CTXLEN"] = str(args.ctx_len)
    os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)
    if args.dim_att <= 0:
        args.dim_att = args.n_embd
    if args.dim_ffn <= 0:
        args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)  # default = 3.5x emb size

    args.run_name = (
        f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
    )
    os.makedirs(args.proj_dir, exist_ok=True)

    if args.my_pile_stage > 0:
        if args.my_pile_shift < 0:
            args.my_pile_shift = 0

        # 40320 samples per "epoch" in pile mode (the reason for this exact
        # constant is not visible here — TODO confirm upstream).
        if args.my_qa_mask == 2:
            args.epoch_count = 2 * args.magic_prime // 40320
        else:
            args.epoch_count = args.magic_prime // 40320

        args.epoch_steps = 40320 // args.real_bsz
        if args.epoch_steps * args.real_bsz != 40320:
            raise ValueError(
                f"real batch size {args.real_bsz} must divide 40320 when --my_pile_stage > 0"
            )
        if args.my_pile_stage >= 2:  # find latest saved model
            list_p = _saved_epochs(args.proj_dir)
            if not list_p:
                raise FileNotFoundError(
                    f"--my_pile_stage {args.my_pile_stage} needs a saved rwkv-*.pth "
                    f"checkpoint in {args.proj_dir}, but none was found"
                )
            max_p = list_p[-1]
            if len(list_p) > 1:
                args.my_pile_prev_p = list_p[-2]  # in case max_p is corrupted
            args.load_model = _checkpoint_path(args.proj_dir, max_p)
            # NOTE(review): nesting reconstructed from whitespace-collapsed
            # source; the warm-up default is assumed to apply only when
            # resuming from a numbered checkpoint — confirm.
            if max_p != -1 and args.warmup_steps < 0:
                args.warmup_steps = 10 if args.my_pile_stage == 2 else 30
            args.epoch_begin = max_p + 1

    samples_per_epoch = args.epoch_steps * args.real_bsz
    tokens_per_epoch = samples_per_epoch * args.ctx_len
    rank_zero_info(
        f"""
############################################################################
#
# RWKV-5 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1}, save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed_version}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.9.5
#
############################################################################
"""
    )
    rank_zero_info(str(vars(args)) + "\n")

    if args.lr_final == 0 or args.lr_init == 0:
        rank_zero_info(
            "\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n"
        )

    os.environ["RWKV_FLOAT_MODE"] = args.precision
    if args.precision == "fp32":
        # Repeated on purpose so the note is hard to miss.
        for _ in range(10):
            rank_zero_info(
                "\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n"
            )
    if args.precision == "fp16":
        rank_zero_info(
            "\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n"
        )

    # JIT is switched off for every strategy in this script.
    os.environ["RWKV_JIT_ON"] = "0"

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    allow_tf32 = args.precision != "fp32"
    torch.backends.cudnn.allow_tf32 = allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32

    # Translate the user-facing precision name into the value given to Trainer.
    if "32" in args.precision:
        args.precision = 32
    elif args.precision == "fp16":
        args.precision = 16
    else:
        args.precision = "bf16"

    ########################################################################################################

    from src.trainer import train_callback, generate_init_weight
    from src.dataset import MyDataset

    train_data = MyDataset(args)
    args.vocab_size = train_data.vocab_size

    from src.model import RWKV, LORA_CONFIG, LoraLinear

    model = RWKV(args)

    if args.lora:
        if args.lora_r <= 0:
            raise ValueError("LoRA should have its `r` > 0")
        # NOTE(review): LORA_CONFIG is filled in after RWKV(args) was built.
        # If src.model reads LORA_CONFIG while constructing layers, these
        # values arrive too late — confirm against src/model.py.
        LORA_CONFIG["r"] = args.lora_r
        LORA_CONFIG["alpha"] = args.lora_alpha
        LORA_CONFIG["dropout"] = args.lora_dropout
        LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
        _enable_lora_training(
            model,
            enable_ln_finetune="ln" in LORA_CONFIG["parts"],
            enable_time_finetune="time" in LORA_CONFIG["parts"],
        )

    if (
        len(args.load_model) == 0 or args.my_pile_stage == 1
    ):  # shall we build the initial weights?
        init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
        generate_init_weight(model, init_weight_name)  # save initial weights
        args.load_model = init_weight_name

    rank_zero_info(f"########## Loading {args.load_model}... ##########")
    try:
        load_dict = _load_state_dict(args.load_model)
    except Exception:
        rank_zero_info(f"Bad checkpoint {args.load_model}")
        prev_p = getattr(args, "my_pile_prev_p", None)
        if args.my_pile_stage < 2 or prev_p is None:
            # No fallback checkpoint: surface the real load error instead of
            # failing later on an undefined `load_dict`.
            raise
        # try again using another checkpoint
        args.load_model = _checkpoint_path(args.proj_dir, prev_p)
        args.epoch_begin = prev_p + 1
        rank_zero_info(f"Trying {args.load_model}")
        load_dict = _load_state_dict(args.load_model)

    if args.load_partial == 1:
        # Fill keys missing from the checkpoint with the model's current values.
        for k, v in model.state_dict().items():
            if k not in load_dict:
                load_dict[k] = v
    model.load_state_dict(load_dict, strict=(not args.lora))
    if os.path.isfile(args.lora_load):
        # NOTE: torch.load unpickles the file; only load trusted LoRA weights.
        model.load_state_dict(
            torch.load(args.lora_load, map_location="cpu"), strict=False
        )

    if is_pl_v2:
        trainer = Trainer(
            accelerator=args.accelerator,
            strategy=args.strategy,
            devices=args.devices,
            num_nodes=args.num_nodes,
            precision=args.precision,
            logger=args.logger,
            callbacks=[train_callback(args)],
            max_epochs=args.max_epochs,
            check_val_every_n_epoch=args.check_val_every_n_epoch,
            num_sanity_val_steps=args.num_sanity_val_steps,
            log_every_n_steps=args.log_every_n_steps,
            enable_checkpointing=args.enable_checkpointing,
            accumulate_grad_batches=args.accumulate_grad_batches,
            gradient_clip_val=args.gradient_clip_val,
        )
    else:
        trainer = Trainer.from_argparse_args(
            args,
            callbacks=[train_callback(args)],
        )

    if trainer.global_rank == 0:
        # Print each parameter's shape with size-1 dimensions dropped.
        for n, tensor in model.state_dict().items():
            shape = [i for i in tensor.shape if i != 1]
            if len(shape) > 1:
                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
            elif shape:
                print(f"{str(shape[0]).ljust(5)} {n}")
            else:
                # Scalar or all-ones shape: nothing is left after filtering.
                print(f"{'1'.ljust(5)} {n}")

    if uses_deepspeed:
        bucket_size = args.ds_bucket_mb * 1000 * 1000
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = bucket_size
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = bucket_size

    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
    data_loader = DataLoader(
        train_data,
        shuffle=False,
        pin_memory=True,
        batch_size=args.micro_bsz,
        num_workers=1,
        persistent_workers=False,
        drop_last=True,
    )

    trainer.fit(model, data_loader)