rwkv6 lora finetune support (https://github.com/JL-er/RWKV-LORA)
This commit is contained in:
parent
c6024520af
commit
333619839a
@ -52,9 +52,13 @@ for x in keys:
|
|||||||
if "time_maa" in x:
|
if "time_maa" in x:
|
||||||
version = max(6, version)
|
version = max(6, version)
|
||||||
|
|
||||||
|
params = f"--vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}"
|
||||||
|
|
||||||
if version <= expected_max_version:
|
if version <= expected_max_version:
|
||||||
|
if version == 6:
|
||||||
|
params += ' --my_testing "x060"'
|
||||||
print(
|
print(
|
||||||
f"v{int(version)}/train.py --vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}",
|
f"v{int(version)}/train.py {params}",
|
||||||
end="",
|
end="",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -53,7 +53,7 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
echo "loading $loadModel"
|
echo "loading $loadModel"
|
||||||
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 5.2)
|
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 6.0)
|
||||||
echo $modelInfo
|
echo $modelInfo
|
||||||
if [[ $modelInfo =~ "--n_layer" ]]; then
|
if [[ $modelInfo =~ "--n_layer" ]]; then
|
||||||
sudo rm -rf /root/.cache/torch_extensions
|
sudo rm -rf /root/.cache/torch_extensions
|
||||||
|
202
finetune/lora/v6/cuda/wkv5_cuda.cu
vendored
Normal file
202
finetune/lora/v6/cuda/wkv5_cuda.cu
vendored
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
#include <stdio.h>
|
||||||
|
#include <assert.h>
|
||||||
|
#include "ATen/ATen.h"
|
||||||
|
typedef at::BFloat16 bf16;
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
|
||||||
|
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
|
||||||
|
F *__restrict__ const _y)
|
||||||
|
{
|
||||||
|
const int b = blockIdx.x / H;
|
||||||
|
const int h = blockIdx.x % H;
|
||||||
|
const int i = threadIdx.x;
|
||||||
|
_w += h*_N_;
|
||||||
|
_u += h*_N_;
|
||||||
|
|
||||||
|
__shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
|
||||||
|
float state[_N_] = {0};
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
w[i] = _w[i];
|
||||||
|
u[i] = float(_u[i]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
r[i] = float(_r[t]);
|
||||||
|
k[i] = float(_k[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float v = float(_v[t]);
|
||||||
|
float y = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j+=4)
|
||||||
|
{
|
||||||
|
const float4& r_ = (float4&)(r[j]);
|
||||||
|
const float4& k_ = (float4&)(k[j]);
|
||||||
|
const float4& w_ = (float4&)(w[j]);
|
||||||
|
const float4& u_ = (float4&)(u[j]);
|
||||||
|
float4& s = (float4&)(state[j]);
|
||||||
|
float4 x;
|
||||||
|
|
||||||
|
x.x = k_.x * v;
|
||||||
|
x.y = k_.y * v;
|
||||||
|
x.z = k_.z * v;
|
||||||
|
x.w = k_.w * v;
|
||||||
|
|
||||||
|
y += r_.x * (u_.x * x.x + s.x);
|
||||||
|
y += r_.y * (u_.y * x.y + s.y);
|
||||||
|
y += r_.z * (u_.z * x.z + s.z);
|
||||||
|
y += r_.w * (u_.w * x.w + s.w);
|
||||||
|
|
||||||
|
s.x = s.x * w_.x + x.x;
|
||||||
|
s.y = s.y * w_.y + x.y;
|
||||||
|
s.z = s.z * w_.z + x.z;
|
||||||
|
s.w = s.w * w_.w + x.w;
|
||||||
|
}
|
||||||
|
_y[t] = F(y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
__global__ void kernel_backward(const int B, const int T, const int C, const int H,
|
||||||
|
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
|
||||||
|
F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
|
||||||
|
{
|
||||||
|
const int b = blockIdx.x / H;
|
||||||
|
const int h = blockIdx.x % H;
|
||||||
|
const int i = threadIdx.x;
|
||||||
|
_w += h*_N_;
|
||||||
|
_u += h*_N_;
|
||||||
|
__w += h*_N_;
|
||||||
|
|
||||||
|
__shared__ float w_[_N_], u_[_N_];
|
||||||
|
__shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
|
||||||
|
__syncthreads();
|
||||||
|
w_[i] = _w[i];
|
||||||
|
u_[i] = float(_u[i]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float w = w_[i];
|
||||||
|
const float ww = __w[i];
|
||||||
|
const float u = u_[i];
|
||||||
|
|
||||||
|
float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
|
||||||
|
|
||||||
|
float gw = 0, gu = 0;
|
||||||
|
const int t000 = b*T*C + h*_N_ + i;
|
||||||
|
const int t111 = (b+1)*T*C + h*_N_ + i;
|
||||||
|
const int t222 = t111 - 2*C;
|
||||||
|
|
||||||
|
for (int t = t000; t < t111; t += C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
v[i] = float(_v[t]);
|
||||||
|
gy[i] = float(_gy[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float k = float(_k[t]);
|
||||||
|
float gr = 0, gu_ = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = state[j];
|
||||||
|
float x = k * v[j];
|
||||||
|
|
||||||
|
gr += (u * x + s) * gy[j];
|
||||||
|
gu_ += x * gy[j];
|
||||||
|
s = s * w + x;
|
||||||
|
}
|
||||||
|
_gr[t] = F(gr);
|
||||||
|
gu += float(_r[t]) * gu_;
|
||||||
|
}
|
||||||
|
_gu[b*C + h*_N_ + i] = F(gu);
|
||||||
|
|
||||||
|
for (int t = t000; t < t222; t += C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
v[i] = float(_v[t]);
|
||||||
|
gy[i] = float(_gy[t + 2*C]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float k = float(_k[t]);
|
||||||
|
float gw_ = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = saaaa[j];
|
||||||
|
float& s2 = sbbbb[j];
|
||||||
|
float x = k * v[j];
|
||||||
|
|
||||||
|
float tmp = w * (x + s);
|
||||||
|
s = tmp;
|
||||||
|
s2 = tmp + w * s2;
|
||||||
|
gw_ += s2 * gy[j];
|
||||||
|
}
|
||||||
|
gw += float(_r[t + 2*C]) * gw_;
|
||||||
|
}
|
||||||
|
_gw[b*C + h*_N_ + i] = F(ww * gw);
|
||||||
|
|
||||||
|
for (int t = t111 - C; t >= t000; t -= C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
v[i] = float(_v[t]);
|
||||||
|
gy[i] = float(_gy[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float rr = float(_r[t]);
|
||||||
|
float gk = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = scccc[j];
|
||||||
|
float x = rr * gy[j];
|
||||||
|
|
||||||
|
gk += (u * x + s) * v[j];
|
||||||
|
s = x + s * w;
|
||||||
|
}
|
||||||
|
_gk[t] = F(gk);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int t = t111 - C; t >= t000; t -= C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
r[i] = float(_r[t]);
|
||||||
|
k[i] = float(_k[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float gyy = float(_gy[t]);
|
||||||
|
float gv = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = sdddd[j];
|
||||||
|
float x = gyy * r[j];
|
||||||
|
|
||||||
|
gv += (u_[j] * x + s) * k[j];
|
||||||
|
s = x + s * w_[j];
|
||||||
|
}
|
||||||
|
_gv[t] = F(gv);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
|
||||||
|
{
|
||||||
|
assert(H*_N_ == C);
|
||||||
|
assert(_N_%4 == 0);
|
||||||
|
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
|
||||||
|
{
|
||||||
|
assert(H*_N_ == C);
|
||||||
|
assert(_N_%4 == 0);
|
||||||
|
kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
|
||||||
|
}
|
22
finetune/lora/v6/cuda/wkv5_op.cpp
vendored
Normal file
22
finetune/lora/v6/cuda/wkv5_op.cpp
vendored
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
#include "ATen/ATen.h"
|
||||||
|
typedef at::BFloat16 bf16;
|
||||||
|
|
||||||
|
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
|
||||||
|
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
|
||||||
|
|
||||||
|
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||||
|
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
|
||||||
|
}
|
||||||
|
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
|
||||||
|
cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
|
||||||
|
}
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("forward", &forward, "wkv5 forward");
|
||||||
|
m.def("backward", &backward, "wkv5 backward");
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY(wkv5, m) {
|
||||||
|
m.def("forward", forward);
|
||||||
|
m.def("backward", backward);
|
||||||
|
}
|
242
finetune/lora/v6/cuda/wkv6_cuda.cu
vendored
Normal file
242
finetune/lora/v6/cuda/wkv6_cuda.cu
vendored
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
#include <stdio.h>
|
||||||
|
#include <assert.h>
|
||||||
|
#include "ATen/ATen.h"
|
||||||
|
typedef at::BFloat16 bf16;
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
|
||||||
|
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
|
||||||
|
F *__restrict__ const _y)
|
||||||
|
{
|
||||||
|
const int b = blockIdx.x / H;
|
||||||
|
const int h = blockIdx.x % H;
|
||||||
|
const int i = threadIdx.x;
|
||||||
|
_u += h*_N_;
|
||||||
|
|
||||||
|
__shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
|
||||||
|
float state[_N_] = {0};
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
u[i] = float(_u[i]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
w[i] = exp(_w[t]);
|
||||||
|
r[i] = float(_r[t]);
|
||||||
|
k[i] = float(_k[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float v = float(_v[t]);
|
||||||
|
float y = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j+=4)
|
||||||
|
{
|
||||||
|
const float4& r_ = (float4&)(r[j]);
|
||||||
|
const float4& k_ = (float4&)(k[j]);
|
||||||
|
const float4& w_ = (float4&)(w[j]);
|
||||||
|
const float4& u_ = (float4&)(u[j]);
|
||||||
|
float4& s = (float4&)(state[j]);
|
||||||
|
float4 x;
|
||||||
|
|
||||||
|
x.x = k_.x * v;
|
||||||
|
x.y = k_.y * v;
|
||||||
|
x.z = k_.z * v;
|
||||||
|
x.w = k_.w * v;
|
||||||
|
|
||||||
|
y += r_.x * (u_.x * x.x + s.x);
|
||||||
|
y += r_.y * (u_.y * x.y + s.y);
|
||||||
|
y += r_.z * (u_.z * x.z + s.z);
|
||||||
|
y += r_.w * (u_.w * x.w + s.w);
|
||||||
|
|
||||||
|
s.x = s.x * w_.x + x.x;
|
||||||
|
s.y = s.y * w_.y + x.y;
|
||||||
|
s.z = s.z * w_.z + x.z;
|
||||||
|
s.w = s.w * w_.w + x.w;
|
||||||
|
}
|
||||||
|
_y[t] = F(y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
__global__ void kernel_backward_111(const int B, const int T, const int C, const int H,
|
||||||
|
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
|
||||||
|
F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu)
|
||||||
|
{
|
||||||
|
const int b = blockIdx.x / H;
|
||||||
|
const int h = blockIdx.x % H;
|
||||||
|
const int i = threadIdx.x;
|
||||||
|
_u += h*_N_;
|
||||||
|
|
||||||
|
__shared__ float u_[_N_];
|
||||||
|
__shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_];
|
||||||
|
__syncthreads();
|
||||||
|
u_[i] = float(_u[i]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float u = u_[i];
|
||||||
|
|
||||||
|
float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
|
||||||
|
|
||||||
|
const int t_0 = b*T*C + h*_N_ + i;
|
||||||
|
const int t_T_1 = t_0 + (T-1)*C;
|
||||||
|
const int t_T = t_0 + T*C;
|
||||||
|
|
||||||
|
float gu = 0;
|
||||||
|
for (int t = t_0; t < t_T; t += C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
v[i] = float(_v[t]);
|
||||||
|
gy[i] = float(_gy[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float k = float(_k[t]);
|
||||||
|
const float w = exp(_w[t]);
|
||||||
|
float gr = 0, gu_ = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = state[j];
|
||||||
|
float x = k * v[j];
|
||||||
|
|
||||||
|
gr += (u * x + s) * gy[j];
|
||||||
|
gu_ += x * gy[j];
|
||||||
|
s = s * w + x;
|
||||||
|
}
|
||||||
|
_gr[t] = F(gr);
|
||||||
|
gu += float(_r[t]) * gu_;
|
||||||
|
}
|
||||||
|
_gu[b*C + h*_N_ + i] = F(gu);
|
||||||
|
|
||||||
|
for (int t = t_T_1; t >= t_0; t -= C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
v[i] = float(_v[t]);
|
||||||
|
gy[i] = float(_gy[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float rr = float(_r[t]);
|
||||||
|
const float w = exp(_w[t]);
|
||||||
|
float gk = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = scccc[j];
|
||||||
|
float x = rr * gy[j];
|
||||||
|
|
||||||
|
gk += (u * x + s) * v[j];
|
||||||
|
s = x + s * w;
|
||||||
|
}
|
||||||
|
_gk[t] = F(gk);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int t = t_T_1; t >= t_0; t -= C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
r[i] = float(_r[t]);
|
||||||
|
k[i] = float(_k[t]);
|
||||||
|
w_[i] = exp(_w[t]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float gyy = float(_gy[t]);
|
||||||
|
float gv = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = sdddd[j];
|
||||||
|
float x = gyy * r[j];
|
||||||
|
|
||||||
|
gv += (u_[j] * x + s) * k[j];
|
||||||
|
s = x + s * w_[j];
|
||||||
|
}
|
||||||
|
_gv[t] = F(gv);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
__global__ void kernel_backward_222(const int B, const int T, const int C, const int H,
|
||||||
|
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
|
||||||
|
F *__restrict__ const _gw)
|
||||||
|
{
|
||||||
|
const int b = blockIdx.x / H;
|
||||||
|
const int h = blockIdx.x % H;
|
||||||
|
const int i = threadIdx.x;
|
||||||
|
|
||||||
|
__shared__ float v[_N_], gy[_N_];
|
||||||
|
float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};
|
||||||
|
|
||||||
|
const int t_0 = b*T*C + h*_N_ + i;
|
||||||
|
const int t_1 = t_0 + C;
|
||||||
|
const int t_2 = t_0 + 2*C;
|
||||||
|
const int t_T_1 = t_0 + (T-1)*C;
|
||||||
|
|
||||||
|
for (int t = t_T_1; t > t_1; t -= C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
gy[i] = float(_gy[t]);
|
||||||
|
v[i] = float(_v[t-2*C]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float r = float(_r[t]);
|
||||||
|
const float w = exp(_w[t-C]);
|
||||||
|
float sum = 0.0f;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = saaaa[j];
|
||||||
|
float x = r * gy[j];
|
||||||
|
s = (s + x) * w;
|
||||||
|
sum += s * v[j];
|
||||||
|
}
|
||||||
|
sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);
|
||||||
|
}
|
||||||
|
|
||||||
|
float sss = sbbbb[0];
|
||||||
|
_gw[t_0] = 0;
|
||||||
|
_gw[t_1] = F(sss * _w[t_1]);
|
||||||
|
|
||||||
|
for (int t = t_2; t < t_T_1; t += C)
|
||||||
|
{
|
||||||
|
__syncthreads();
|
||||||
|
gy[i] = float(_gy[t]);
|
||||||
|
v[i] = float(_v[t-2*C]);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float w = exp(_w[t-C]);
|
||||||
|
const float k = float(_k[t-2*C]);
|
||||||
|
float sum = 0.0f;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < _N_; j++)
|
||||||
|
{
|
||||||
|
float& s = scccc[j];
|
||||||
|
float x = k * v[j];
|
||||||
|
s = (s + x) * w;
|
||||||
|
sum += s * gy[j];
|
||||||
|
}
|
||||||
|
sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
|
||||||
|
_gw[t] = F(sss * _w[t]);
|
||||||
|
}
|
||||||
|
_gw[t_T_1] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
|
||||||
|
{
|
||||||
|
assert(H*_N_ == C);
|
||||||
|
assert(_N_%4 == 0);
|
||||||
|
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
|
||||||
|
{
|
||||||
|
assert(H*_N_ == C);
|
||||||
|
assert(_N_%4 == 0);
|
||||||
|
kernel_backward_111<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu);
|
||||||
|
kernel_backward_222<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gw);
|
||||||
|
}
|
22
finetune/lora/v6/cuda/wkv6_op.cpp
vendored
Normal file
22
finetune/lora/v6/cuda/wkv6_op.cpp
vendored
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
#include "ATen/ATen.h"
|
||||||
|
typedef at::BFloat16 bf16;
|
||||||
|
|
||||||
|
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
|
||||||
|
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
|
||||||
|
|
||||||
|
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||||
|
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
|
||||||
|
}
|
||||||
|
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
|
||||||
|
cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
|
||||||
|
}
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("forward", &forward, "wkv6 forward");
|
||||||
|
m.def("backward", &backward, "wkv6 backward");
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY(wkv6, m) {
|
||||||
|
m.def("forward", forward);
|
||||||
|
m.def("backward", backward);
|
||||||
|
}
|
0
finetune/lora/v6/src/__init__.py
vendored
Normal file
0
finetune/lora/v6/src/__init__.py
vendored
Normal file
303
finetune/lora/v6/src/binidx.py
vendored
Normal file
303
finetune/lora/v6/src/binidx.py
vendored
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
from lib2to3.pgen2 import token
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import shutil
|
||||||
|
import struct
|
||||||
|
from functools import lru_cache
|
||||||
|
from itertools import accumulate
|
||||||
|
|
||||||
|
|
||||||
|
def print_rank_0(*message):
|
||||||
|
pass
|
||||||
|
# """If distributed is initialized print only on rank 0."""
|
||||||
|
# if torch.distributed.is_initialized():
|
||||||
|
# if torch.distributed.get_rank() == 0:
|
||||||
|
# print(*message, flush=True)
|
||||||
|
# else:
|
||||||
|
# print(*message, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _warmup_mmap_file(path):
|
||||||
|
pass
|
||||||
|
# with open(path, "rb") as stream:
|
||||||
|
# while stream.read(100 * 1024 * 1024):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
|
dtypes = {
|
||||||
|
1: np.uint8,
|
||||||
|
2: np.int8,
|
||||||
|
3: np.int16,
|
||||||
|
4: np.int32,
|
||||||
|
5: np.int64,
|
||||||
|
6: float,
|
||||||
|
7: np.double,
|
||||||
|
8: np.uint16,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def code(dtype):
|
||||||
|
for k in dtypes.keys():
|
||||||
|
if dtypes[k] == dtype:
|
||||||
|
return k
|
||||||
|
raise ValueError(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def index_file_path(prefix_path):
|
||||||
|
return prefix_path + ".idx"
|
||||||
|
|
||||||
|
|
||||||
|
def data_file_path(prefix_path):
|
||||||
|
return prefix_path + ".bin"
|
||||||
|
|
||||||
|
|
||||||
|
class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||||
|
class Index(object):
|
||||||
|
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def writer(cls, path, dtype):
|
||||||
|
class _Writer(object):
|
||||||
|
def __enter__(self):
|
||||||
|
self._file = open(path, "wb")
|
||||||
|
|
||||||
|
# Write Magic string so we can check the file format then opening it again.
|
||||||
|
self._file.write(cls._HDR_MAGIC)
|
||||||
|
# Write version number
|
||||||
|
# Little endian unsigned 64 Bit integer
|
||||||
|
self._file.write(struct.pack("<Q", 1))
|
||||||
|
# Little endian unsigned 8 Bit integer
|
||||||
|
self._file.write(struct.pack("<B", code(dtype)))
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_pointers(sizes):
|
||||||
|
dtype_size = dtype().itemsize
|
||||||
|
address = 0
|
||||||
|
pointers = []
|
||||||
|
|
||||||
|
for size in sizes:
|
||||||
|
pointers.append(address)
|
||||||
|
address += size * dtype_size
|
||||||
|
|
||||||
|
return pointers
|
||||||
|
|
||||||
|
def write(self, sizes, doc_idx):
|
||||||
|
pointers = self._get_pointers(sizes)
|
||||||
|
|
||||||
|
# Little endian unsigned 64 Bit integer
|
||||||
|
self._file.write(struct.pack("<Q", len(sizes)))
|
||||||
|
# Little endian unsigned 64 Bit integer
|
||||||
|
self._file.write(struct.pack("<Q", len(doc_idx)))
|
||||||
|
|
||||||
|
sizes = np.array(sizes, dtype=np.int32)
|
||||||
|
self._file.write(sizes.tobytes(order="C"))
|
||||||
|
del sizes
|
||||||
|
|
||||||
|
pointers = np.array(pointers, dtype=np.int64)
|
||||||
|
self._file.write(pointers.tobytes(order="C"))
|
||||||
|
del pointers
|
||||||
|
|
||||||
|
doc_idx = np.array(doc_idx, dtype=np.int64)
|
||||||
|
self._file.write(doc_idx.tobytes(order="C"))
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self._file.close()
|
||||||
|
|
||||||
|
return _Writer()
|
||||||
|
|
||||||
|
def __init__(self, path, skip_warmup=False):
|
||||||
|
with open(path, "rb") as stream:
|
||||||
|
magic_test = stream.read(9)
|
||||||
|
assert self._HDR_MAGIC == magic_test, (
|
||||||
|
"Index file doesn't match expected format. "
|
||||||
|
"Make sure that --dataset-impl is configured properly."
|
||||||
|
)
|
||||||
|
# Little endian unsigned 64 Bit integer
|
||||||
|
version = struct.unpack("<Q", stream.read(8))
|
||||||
|
assert (1,) == version
|
||||||
|
|
||||||
|
# Little endian unsigned 8 Bit integer
|
||||||
|
(dtype_code,) = struct.unpack("<B", stream.read(1))
|
||||||
|
self._dtype = dtypes[dtype_code]
|
||||||
|
self._dtype_size = self._dtype().itemsize
|
||||||
|
|
||||||
|
self._len = struct.unpack("<Q", stream.read(8))[0]
|
||||||
|
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
|
||||||
|
offset = stream.tell()
|
||||||
|
|
||||||
|
if not skip_warmup:
|
||||||
|
print_rank_0(" warming up index mmap file...")
|
||||||
|
_warmup_mmap_file(path)
|
||||||
|
|
||||||
|
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
||||||
|
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||||
|
print_rank_0(" reading sizes...")
|
||||||
|
self._sizes = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
|
||||||
|
)
|
||||||
|
print_rank_0(" reading pointers...")
|
||||||
|
self._pointers = np.frombuffer(
|
||||||
|
self._bin_buffer,
|
||||||
|
dtype=np.int64,
|
||||||
|
count=self._len,
|
||||||
|
offset=offset + self._sizes.nbytes,
|
||||||
|
)
|
||||||
|
print_rank_0(" reading document index...")
|
||||||
|
self._doc_idx = np.frombuffer(
|
||||||
|
self._bin_buffer,
|
||||||
|
dtype=np.int64,
|
||||||
|
count=self._doc_count,
|
||||||
|
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self._bin_buffer_mmap._mmap.close()
|
||||||
|
del self._bin_buffer_mmap
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self._dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sizes(self):
|
||||||
|
return self._sizes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def doc_idx(self):
|
||||||
|
return self._doc_idx
|
||||||
|
|
||||||
|
@lru_cache(maxsize=8)
|
||||||
|
def __getitem__(self, i):
|
||||||
|
return self._pointers[i], self._sizes[i]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._len
|
||||||
|
|
||||||
|
def __init__(self, path, skip_warmup=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._path = None
|
||||||
|
self._index = None
|
||||||
|
self._bin_buffer = None
|
||||||
|
|
||||||
|
self._do_init(path, skip_warmup)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return self._path
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self._do_init(state)
|
||||||
|
|
||||||
|
def _do_init(self, path, skip_warmup):
|
||||||
|
self._path = path
|
||||||
|
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
||||||
|
|
||||||
|
if not skip_warmup:
|
||||||
|
print_rank_0(" warming up data mmap file...")
|
||||||
|
_warmup_mmap_file(data_file_path(self._path))
|
||||||
|
print_rank_0(" creating numpy buffer of mmap...")
|
||||||
|
self._bin_buffer_mmap = np.memmap(
|
||||||
|
data_file_path(self._path), mode="r", order="C"
|
||||||
|
)
|
||||||
|
print_rank_0(" creating memory view of numpy buffer...")
|
||||||
|
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self._bin_buffer_mmap._mmap.close()
|
||||||
|
del self._bin_buffer_mmap
|
||||||
|
del self._index
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._index)
|
||||||
|
|
||||||
|
# @lru_cache(maxsize=8)
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
if isinstance(idx, int):
|
||||||
|
ptr, size = self._index[idx]
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||||
|
)
|
||||||
|
return np_array
|
||||||
|
elif isinstance(idx, slice):
|
||||||
|
start, stop, step = idx.indices(len(self))
|
||||||
|
if step != 1:
|
||||||
|
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||||
|
ptr = self._index._pointers[start]
|
||||||
|
sizes = self._index._sizes[idx]
|
||||||
|
offsets = list(accumulate(sizes))
|
||||||
|
total_size = sum(sizes)
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
|
||||||
|
)
|
||||||
|
sents = np.split(np_array, offsets[:-1])
|
||||||
|
return sents
|
||||||
|
|
||||||
|
def get(self, idx, offset=0, length=None):
|
||||||
|
"""Retrieves a single item from the dataset with the option to only
|
||||||
|
return a portion of the item.
|
||||||
|
|
||||||
|
get(idx) is the same as [idx] but get() does not support slicing.
|
||||||
|
"""
|
||||||
|
ptr, size = self._index[idx]
|
||||||
|
if length is None:
|
||||||
|
length = size - offset
|
||||||
|
ptr += offset * np.dtype(self._index.dtype).itemsize
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
||||||
|
)
|
||||||
|
return np_array
|
||||||
|
|
||||||
|
def pad(self, idx, length=None):
|
||||||
|
ptr, size = self._index[idx]
|
||||||
|
try:
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||||
|
)
|
||||||
|
ptr0, _ = self._index[0]
|
||||||
|
np_array0 = np.frombuffer(
|
||||||
|
self._bin_buffer,
|
||||||
|
dtype=self._index.dtype,
|
||||||
|
count=length - size,
|
||||||
|
offset=ptr0,
|
||||||
|
)
|
||||||
|
np_array = np.append(np_array, np_array0)
|
||||||
|
return np_array
|
||||||
|
|
||||||
|
def only(self, idx):
|
||||||
|
ptr, size = self._index[idx]
|
||||||
|
np_array = np.frombuffer(
|
||||||
|
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||||
|
)
|
||||||
|
|
||||||
|
return np_array
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sizes(self):
|
||||||
|
return self._index.sizes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def doc_idx(self):
|
||||||
|
return self._index.doc_idx
|
||||||
|
|
||||||
|
def get_doc_idx(self):
|
||||||
|
return self._index._doc_idx
|
||||||
|
|
||||||
|
def set_doc_idx(self, doc_idx_):
|
||||||
|
self._index._doc_idx = doc_idx_
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_prefetch(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def exists(path):
|
||||||
|
return os.path.exists(index_file_path(path)) and os.path.exists(
|
||||||
|
data_file_path(path)
|
||||||
|
)
|
242
finetune/lora/v6/src/dataset.py
vendored
Normal file
242
finetune/lora/v6/src/dataset.py
vendored
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
########################################################################################################
|
||||||
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
import json, math, random, os, sys
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from pytorch_lightning.utilities import rank_zero_info
|
||||||
|
from .binidx import MMapIndexedDataset
|
||||||
|
from .utils import MaybeIsPrime
|
||||||
|
|
||||||
|
|
||||||
|
class MyDataset(Dataset):
|
||||||
|
def __init__(self, args):
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
if args.data_type == "binidx":
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
rank_zero_info(
|
||||||
|
f"Current vocab size = {self.vocab_size} (make sure it's correct)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.my_pile_version == 1:
|
||||||
|
self.data = MMapIndexedDataset(args.data_file)
|
||||||
|
self.data_size = (
|
||||||
|
len(self.data._bin_buffer) // self.data._index._dtype_size
|
||||||
|
)
|
||||||
|
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||||
|
elif args.my_pile_version == 2:
|
||||||
|
data_list = (
|
||||||
|
open(args.data_file, "r", encoding="utf-8")
|
||||||
|
.read()
|
||||||
|
.strip()
|
||||||
|
.split("\n")
|
||||||
|
)
|
||||||
|
data_list = [i.strip().split(" ") for i in data_list]
|
||||||
|
self.data = []
|
||||||
|
self.data_size = int(data_list[-1][-1])
|
||||||
|
rank_zero_info(f"Data has {self.data_size} chunks.")
|
||||||
|
for d in data_list:
|
||||||
|
data = MMapIndexedDataset(d[0])
|
||||||
|
data_size = len(data._bin_buffer) // data._index._dtype_size
|
||||||
|
assert (data_size - args.ctx_len) == int(d[1])
|
||||||
|
self.data += [[int(d[-1]), int(d[1]), data]]
|
||||||
|
# rank_zero_info(self.data)
|
||||||
|
|
||||||
|
if args.my_qa_mask > 0:
|
||||||
|
# self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
|
||||||
|
self.data_pile = MMapIndexedDataset(
|
||||||
|
"/fsx/pile_deduped/pile_0.87_deduped_text_document"
|
||||||
|
)
|
||||||
|
self.data_pile_size = (
|
||||||
|
len(self.data_pile._bin_buffer) // self.data._index._dtype_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.data_pile = None
|
||||||
|
self.data_pile_size = 0
|
||||||
|
|
||||||
|
if args.my_pile_stage > 0:
|
||||||
|
# assert self.data_size == 332115325534 and self.vocab_size == 50277
|
||||||
|
self.samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||||
|
assert self.samples_per_epoch == 40320
|
||||||
|
rank_zero_info(
|
||||||
|
f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########"
|
||||||
|
)
|
||||||
|
dataset_slot = self.data_size // args.ctx_len
|
||||||
|
if args.my_pile_stage != 4:
|
||||||
|
assert MaybeIsPrime(args.magic_prime)
|
||||||
|
assert args.magic_prime % 3 == 2
|
||||||
|
assert (
|
||||||
|
args.magic_prime / dataset_slot > 0.99
|
||||||
|
and args.magic_prime / dataset_slot <= 1
|
||||||
|
)
|
||||||
|
elif args.data_type == "numpy":
|
||||||
|
self.data = np.load(args.data_file).astype("int")
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
rank_zero_info(
|
||||||
|
f"Current vocab size = {self.vocab_size} (make sure it's correct)"
|
||||||
|
)
|
||||||
|
self.data_size = len(self.data)
|
||||||
|
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||||
|
elif args.data_type == "uint16":
|
||||||
|
self.data = (
|
||||||
|
np.fromfile(args.data_file, dtype=np.uint16)
|
||||||
|
.astype("int32")
|
||||||
|
.reshape(-1, args.my_sample_len)
|
||||||
|
)
|
||||||
|
self.vocab_size = args.vocab_size
|
||||||
|
rank_zero_info(
|
||||||
|
f"Current vocab size = {self.vocab_size} (make sure it's correct)"
|
||||||
|
)
|
||||||
|
self.data_size = self.data.shape[0]
|
||||||
|
rank_zero_info(f"Data has {self.data_size} samples.")
|
||||||
|
else:
|
||||||
|
if args.data_type == "dummy":
|
||||||
|
rank_zero_info("Building dummy data...")
|
||||||
|
self.data = ""
|
||||||
|
for i in range(100000):
|
||||||
|
aa = (i) % 10000
|
||||||
|
bb = (i * i) % 10000
|
||||||
|
cc = aa + bb
|
||||||
|
self.data += f".{aa}+{bb}={cc}."
|
||||||
|
else:
|
||||||
|
self.data = open(args.data_file, "r", encoding=args.data_type).read()
|
||||||
|
rank_zero_info("Building token list...")
|
||||||
|
unique = sorted(list(set(self.data)))
|
||||||
|
self.vocab_size = len(unique)
|
||||||
|
# rank_zero_info()
|
||||||
|
# for u in unique:
|
||||||
|
# print(u, end=' ')
|
||||||
|
# rank_zero_info('\n\n')
|
||||||
|
xx = 0
|
||||||
|
xxObj = {}
|
||||||
|
for u in unique:
|
||||||
|
xxObj[xx] = u
|
||||||
|
xx += 1
|
||||||
|
with open(
|
||||||
|
f"{args.proj_dir}/vocab.json", "w", encoding="utf-8"
|
||||||
|
) as vocab_file:
|
||||||
|
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
|
||||||
|
self.data_size = len(self.data)
|
||||||
|
rank_zero_info(
|
||||||
|
f"Data has {self.data_size} tokens, {self.vocab_size} vocab size."
|
||||||
|
)
|
||||||
|
self.stoi = {ch: i for i, ch in enumerate(unique)}
|
||||||
|
self.itos = {i: ch for i, ch in enumerate(unique)}
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.args.epoch_steps * self.args.micro_bsz
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
args = self.args
|
||||||
|
rank = self.global_rank
|
||||||
|
epoch = self.real_epoch
|
||||||
|
world_size = self.world_size
|
||||||
|
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
|
||||||
|
|
||||||
|
if args.data_type == "uint16":
|
||||||
|
i = np.random.randint(0, self.data_size - 1)
|
||||||
|
dix = self.data[i]
|
||||||
|
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||||
|
y = torch.tensor(dix[1:], dtype=torch.long)
|
||||||
|
else:
|
||||||
|
ctx_len = args.ctx_len
|
||||||
|
req_len = ctx_len + 1
|
||||||
|
magic_prime = args.magic_prime
|
||||||
|
data = self.data
|
||||||
|
|
||||||
|
if args.my_pile_stage > 0:
|
||||||
|
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
|
||||||
|
|
||||||
|
if args.my_qa_mask > 0:
|
||||||
|
ii_orig = ii
|
||||||
|
if ii % 2 == 0:
|
||||||
|
ii = -1
|
||||||
|
data = self.data_pile
|
||||||
|
else:
|
||||||
|
ii = ii // 2
|
||||||
|
if data == self.data_pile:
|
||||||
|
i = np.random.randint(0, self.data_pile_size - req_len)
|
||||||
|
else:
|
||||||
|
if args.my_pile_stage == 4 or ii < args.my_random_steps:
|
||||||
|
# cheat: pick a random spot in dataset
|
||||||
|
if args.my_pile_version == 1:
|
||||||
|
i = np.random.randint(0, self.data_size - req_len)
|
||||||
|
else:
|
||||||
|
i = np.random.randint(0, self.data_size)
|
||||||
|
else:
|
||||||
|
ii = ii - args.my_random_steps
|
||||||
|
factor = (math.sqrt(5) - 1) / 2
|
||||||
|
factor = int(magic_prime * factor)
|
||||||
|
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
|
||||||
|
i = i + args.my_pile_shift
|
||||||
|
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
|
||||||
|
else:
|
||||||
|
# cheat: pick a random spot in dataset
|
||||||
|
i = np.random.randint(0, self.data_size - req_len)
|
||||||
|
|
||||||
|
if args.data_type == "binidx":
|
||||||
|
if args.my_pile_version == 1:
|
||||||
|
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
|
||||||
|
# dix = data.pad(idx=idx, length=req_len).astype(int)
|
||||||
|
else:
|
||||||
|
# self.data : cutoff, chunk_count, data
|
||||||
|
for j in range(len(data)):
|
||||||
|
if i < data[j][0]:
|
||||||
|
ii = i
|
||||||
|
i = (i - (data[j - 1][0] if j > 0 else 0)) % data[j][1]
|
||||||
|
dix = (
|
||||||
|
data[j][2]
|
||||||
|
.get(idx=0, offset=i, length=req_len)
|
||||||
|
.astype(int)
|
||||||
|
)
|
||||||
|
# print(ii, j, i)
|
||||||
|
break
|
||||||
|
elif args.data_type == "numpy":
|
||||||
|
dix = data[i : i + req_len]
|
||||||
|
else:
|
||||||
|
dix = [self.stoi[s] for s in data[i : i + req_len]]
|
||||||
|
|
||||||
|
if args.my_qa_mask == 1:
|
||||||
|
if data == self.data_pile:
|
||||||
|
z = [1] * ctx_len
|
||||||
|
else:
|
||||||
|
z = [0] * ctx_len
|
||||||
|
z_sum = 0
|
||||||
|
isGood = False
|
||||||
|
for i in range(3, ctx_len):
|
||||||
|
if (
|
||||||
|
dix[i] == 27
|
||||||
|
and dix[i - 1] == 34
|
||||||
|
and dix[i - 2] == 187
|
||||||
|
and dix[i - 3] == 187
|
||||||
|
):
|
||||||
|
isGood = True
|
||||||
|
if dix[i] == 0:
|
||||||
|
isGood = False
|
||||||
|
if isGood:
|
||||||
|
z[i] = 1
|
||||||
|
z_sum += 1
|
||||||
|
if z_sum == 0:
|
||||||
|
z = [1] * ctx_len
|
||||||
|
i = np.random.randint(0, self.data_pile_size - req_len)
|
||||||
|
dix = self.data_pile.get(
|
||||||
|
idx=0, offset=i, length=req_len
|
||||||
|
).astype(int)
|
||||||
|
z = torch.tensor(z, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||||
|
y = torch.tensor(dix[1:], dtype=torch.long)
|
||||||
|
|
||||||
|
# if ii_orig < 50:
|
||||||
|
# # if rank == 1:
|
||||||
|
# print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
|
||||||
|
# else:
|
||||||
|
# exit(0)
|
||||||
|
|
||||||
|
if args.my_qa_mask == 1:
|
||||||
|
return x, y, z
|
||||||
|
|
||||||
|
return x, y
|
1086
finetune/lora/v6/src/model.py
vendored
Normal file
1086
finetune/lora/v6/src/model.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
310
finetune/lora/v6/src/trainer.py
vendored
Normal file
310
finetune/lora/v6/src/trainer.py
vendored
Normal file
@ -0,0 +1,310 @@
|
|||||||
|
import os, math, time, datetime, subprocess
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||||
|
from .model import LORA_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
def my_save(args, trainer, dd, ff):
|
||||||
|
if "14b-run1" in ff:
|
||||||
|
fn = ff.split("/")[-1]
|
||||||
|
fff = "/dev/shm/" + fn
|
||||||
|
torch.save(dd, fff)
|
||||||
|
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
|
||||||
|
elif ("world/14b" in ff) or ("world/7b" in ff):
|
||||||
|
aa = ff.split("/")[1]
|
||||||
|
fn = ff.split("/")[-1]
|
||||||
|
fff = f"/dev/shm/{aa}-{fn}"
|
||||||
|
torch.save(dd, fff)
|
||||||
|
subprocess.Popen(
|
||||||
|
f" aws s3 mv {fff} s3://rwkv-world/{aa}-{fn} --quiet", shell=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if "deepspeed_stage_3" in args.strategy:
|
||||||
|
trainer.save_checkpoint(ff, weights_only=True)
|
||||||
|
else:
|
||||||
|
torch.save(dd, ff)
|
||||||
|
|
||||||
|
|
||||||
|
class train_callback(pl.Callback):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
||||||
|
args = self.args
|
||||||
|
# if args.cuda_cleanup > 0:
|
||||||
|
# torch.cuda.empty_cache()
|
||||||
|
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
||||||
|
|
||||||
|
# LR schedule
|
||||||
|
w_step = args.warmup_steps
|
||||||
|
if args.lr_final == args.lr_init or args.epoch_count == 0:
|
||||||
|
lr = args.lr_init
|
||||||
|
else:
|
||||||
|
decay_step = real_step - args.my_pile_edecay * args.epoch_steps
|
||||||
|
decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
|
||||||
|
progress = (decay_step - w_step + 1) / (decay_total - w_step)
|
||||||
|
progress = min(1, max(0, progress))
|
||||||
|
|
||||||
|
if args.lr_final == 0 or args.lr_init == 0: # linear decay
|
||||||
|
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
|
||||||
|
else: # exp decay
|
||||||
|
lr = args.lr_init * math.exp(
|
||||||
|
math.log(args.lr_final / args.lr_init) * pow(progress, 1)
|
||||||
|
)
|
||||||
|
# if trainer.is_global_zero:
|
||||||
|
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
|
||||||
|
|
||||||
|
if args.my_exit_tokens != 0: # cosine decay
|
||||||
|
real_tokens = real_step * args.ctx_len * args.real_bsz
|
||||||
|
warmup_tokens = w_step * args.ctx_len * args.real_bsz
|
||||||
|
progress = (real_tokens - warmup_tokens) / (
|
||||||
|
abs(args.my_exit_tokens) - warmup_tokens
|
||||||
|
)
|
||||||
|
progress = max(0, min(1, progress))
|
||||||
|
lr_final_factor = args.lr_final / args.lr_init
|
||||||
|
lr_mult = (0.5 + lr_final_factor / 2) + (
|
||||||
|
0.5 - lr_final_factor / 2
|
||||||
|
) * math.cos(math.pi * progress)
|
||||||
|
if args.my_exit_tokens > 0:
|
||||||
|
lr = args.lr_init * lr_mult
|
||||||
|
else:
|
||||||
|
lr = (lr + args.lr_init * lr_mult) / 2
|
||||||
|
if progress >= 1:
|
||||||
|
if (trainer.is_global_zero) or ("deepspeed_stage_3" in args.strategy):
|
||||||
|
my_save(
|
||||||
|
args,
|
||||||
|
trainer,
|
||||||
|
pl_module.state_dict(),
|
||||||
|
f"{args.proj_dir}/rwkv-final.pth",
|
||||||
|
)
|
||||||
|
exit(0)
|
||||||
|
if trainer.global_step < w_step:
|
||||||
|
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
|
||||||
|
|
||||||
|
if args.weight_decay_final > 0:
|
||||||
|
wd_now = args.weight_decay * math.exp(
|
||||||
|
math.log(args.weight_decay_final / args.weight_decay) * progress
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
wd_now = args.weight_decay
|
||||||
|
|
||||||
|
for param_group in trainer.optimizers[0].param_groups:
|
||||||
|
if param_group["weight_decay"] > 0:
|
||||||
|
param_group["weight_decay"] = wd_now
|
||||||
|
if args.layerwise_lr > 0:
|
||||||
|
param_group["lr"] = lr * param_group["my_lr_scale"]
|
||||||
|
# print(param_group["lr"], param_group["my_lr_scale"])
|
||||||
|
else:
|
||||||
|
param_group["lr"] = lr
|
||||||
|
|
||||||
|
trainer.my_lr = lr
|
||||||
|
trainer.my_wd = wd_now
|
||||||
|
# rank_zero_info(f"{real_step} {lr}")
|
||||||
|
|
||||||
|
if trainer.global_step == 0:
|
||||||
|
if trainer.is_global_zero: # logging
|
||||||
|
trainer.my_loss_sum = 0
|
||||||
|
trainer.my_loss_count = 0
|
||||||
|
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
|
||||||
|
trainer.my_log.write(
|
||||||
|
f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
print(f"\n{trainer.strategy.config}\n")
|
||||||
|
trainer.my_log.write(f"{trainer.strategy.config}\n")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
trainer.my_log.flush()
|
||||||
|
if len(args.wandb) > 0:
|
||||||
|
print("Login to wandb...")
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
wandb.init(
|
||||||
|
project=args.wandb,
|
||||||
|
name=args.run_name + " " + args.my_timestamp,
|
||||||
|
config=args,
|
||||||
|
save_code=False,
|
||||||
|
)
|
||||||
|
trainer.my_wandb = wandb
|
||||||
|
|
||||||
|
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||||
|
args = self.args
|
||||||
|
token_per_step = args.ctx_len * args.real_bsz
|
||||||
|
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
||||||
|
if trainer.is_global_zero: # logging
|
||||||
|
t_now = time.time_ns()
|
||||||
|
kt_s = 0
|
||||||
|
try:
|
||||||
|
t_cost = (t_now - trainer.my_time_ns) / 1e9
|
||||||
|
kt_s = token_per_step / t_cost / 1000
|
||||||
|
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
|
||||||
|
self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
trainer.my_time_ns = t_now
|
||||||
|
if pl.__version__[0] == "2":
|
||||||
|
trainer.my_loss = outputs["loss"]
|
||||||
|
else:
|
||||||
|
trainer.my_loss = trainer.my_loss_all.float().mean().item()
|
||||||
|
trainer.my_loss_sum += trainer.my_loss
|
||||||
|
trainer.my_loss_count += 1
|
||||||
|
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
|
||||||
|
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
|
||||||
|
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
|
||||||
|
# self.log("s", real_step, prog_bar=True, on_step=True)
|
||||||
|
|
||||||
|
if len(args.wandb) > 0:
|
||||||
|
lll = {
|
||||||
|
"loss": trainer.my_loss,
|
||||||
|
"lr": trainer.my_lr,
|
||||||
|
"wd": trainer.my_wd,
|
||||||
|
"Gtokens": real_step * token_per_step / 1e9,
|
||||||
|
}
|
||||||
|
if kt_s > 0:
|
||||||
|
lll["kt/s"] = kt_s
|
||||||
|
trainer.my_wandb.log(lll, step=int(real_step))
|
||||||
|
if (trainer.is_global_zero) or (
|
||||||
|
"deepspeed_stage_3" in args.strategy
|
||||||
|
): # save pth
|
||||||
|
if args.magic_prime > 0:
|
||||||
|
expand_factor = 2 if args.my_qa_mask > 0 else 1
|
||||||
|
if int(real_step) == int(
|
||||||
|
args.magic_prime * expand_factor // args.real_bsz
|
||||||
|
) - 1 + int(args.my_random_steps):
|
||||||
|
to_save_dict = pl_module.state_dict()
|
||||||
|
my_save(
|
||||||
|
args,
|
||||||
|
trainer,
|
||||||
|
to_save_dict,
|
||||||
|
f"{args.proj_dir}/rwkv-final.pth",
|
||||||
|
)
|
||||||
|
# if args.batch_save==batch_idx :
|
||||||
|
# to_save_dict = pl_module.state_dict()
|
||||||
|
# for name, state in to_save_dict.items():
|
||||||
|
# if 'img' in name:
|
||||||
|
# to_save_dict[name] = state
|
||||||
|
# try:
|
||||||
|
# my_save(
|
||||||
|
# args, trainer,
|
||||||
|
# to_save_dict,
|
||||||
|
# f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}-{batch_idx}.pth",
|
||||||
|
# )
|
||||||
|
# except Exception as e:
|
||||||
|
# print('Error\n\n', e, '\n\n')
|
||||||
|
|
||||||
|
def on_train_epoch_start(self, trainer, pl_module):
|
||||||
|
args = self.args
|
||||||
|
if pl.__version__[0] == "2":
|
||||||
|
dataset = trainer.train_dataloader.dataset
|
||||||
|
else:
|
||||||
|
dataset = trainer.train_dataloader.dataset.datasets
|
||||||
|
assert "MyDataset" in str(dataset)
|
||||||
|
dataset.global_rank = trainer.global_rank
|
||||||
|
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
|
||||||
|
dataset.world_size = trainer.world_size
|
||||||
|
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
|
||||||
|
|
||||||
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
|
args = self.args
|
||||||
|
to_save_dict = {}
|
||||||
|
if (trainer.is_global_zero) or (
|
||||||
|
"deepspeed_stage_3" in args.strategy
|
||||||
|
): # save pth
|
||||||
|
if (
|
||||||
|
args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0
|
||||||
|
) or (trainer.current_epoch == args.epoch_count - 1):
|
||||||
|
if args.data_type == "wds_img":
|
||||||
|
raw_dict = pl_module.state_dict()
|
||||||
|
for k in raw_dict:
|
||||||
|
if k.startswith("encoder.") or k.startswith("decoder."):
|
||||||
|
to_save_dict[k] = raw_dict[k]
|
||||||
|
else:
|
||||||
|
to_save_dict = pl_module.state_dict()
|
||||||
|
|
||||||
|
if args.data_type == "img" and not args.lora:
|
||||||
|
for name, state in to_save_dict.items():
|
||||||
|
if "img" in name:
|
||||||
|
to_save_dict[name] = state
|
||||||
|
|
||||||
|
if args.lora:
|
||||||
|
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
||||||
|
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
||||||
|
lora_dict = {}
|
||||||
|
for name, state in to_save_dict.items():
|
||||||
|
if "img" in name:
|
||||||
|
lora_dict[name] = state
|
||||||
|
if (
|
||||||
|
".lora_" in name
|
||||||
|
or (enable_time_finetune and ".time_" in name)
|
||||||
|
or (enable_ln_finetune and ".ln" in name)
|
||||||
|
):
|
||||||
|
lora_dict[name] = state
|
||||||
|
to_save_dict = lora_dict
|
||||||
|
|
||||||
|
try:
|
||||||
|
my_save(
|
||||||
|
args,
|
||||||
|
trainer,
|
||||||
|
to_save_dict,
|
||||||
|
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error\n\n", e, "\n\n")
|
||||||
|
|
||||||
|
if trainer.is_global_zero: # logging
|
||||||
|
trainer.my_log.write(
|
||||||
|
f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n"
|
||||||
|
)
|
||||||
|
trainer.my_log.flush()
|
||||||
|
|
||||||
|
trainer.my_loss_sum = 0
|
||||||
|
trainer.my_loss_count = 0
|
||||||
|
if (args.epoch_begin + trainer.current_epoch) >= args.my_exit:
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def generate_init_weight(model, init_weight_name):
|
||||||
|
mm = model.generate_init_weight()
|
||||||
|
|
||||||
|
if model.args.my_pile_stage == 1:
|
||||||
|
if len(model.args.load_model) > 0:
|
||||||
|
print(f"Combine weights from {model.args.load_model}...")
|
||||||
|
load_dict = torch.load(model.args.load_model, map_location="cpu")
|
||||||
|
for k in load_dict:
|
||||||
|
try:
|
||||||
|
assert k in mm
|
||||||
|
except:
|
||||||
|
print("missing", k)
|
||||||
|
exit(0)
|
||||||
|
src = load_dict[k]
|
||||||
|
try:
|
||||||
|
mm[k] = src.reshape(mm[k].shape)
|
||||||
|
except:
|
||||||
|
tmp = mm[k].squeeze().clone()
|
||||||
|
print(k, src.shape, "-->", mm[k].shape)
|
||||||
|
ss = src.shape[0]
|
||||||
|
dd = tmp.shape[0]
|
||||||
|
for i in range(dd):
|
||||||
|
pos = i / dd * ss
|
||||||
|
if pos >= ss - 1:
|
||||||
|
tmp[i] = src[ss - 1]
|
||||||
|
else:
|
||||||
|
p0 = int(math.floor(pos))
|
||||||
|
ii = pos - p0
|
||||||
|
tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii)
|
||||||
|
mm[k] = tmp.reshape(mm[k].shape)
|
||||||
|
sss = src.squeeze().float().cpu().numpy()
|
||||||
|
print(sss[:10], "...", sss[-10:])
|
||||||
|
mmm = mm[k].squeeze().float().cpu().numpy()
|
||||||
|
print(mmm[:10], "...", mmm[-10:])
|
||||||
|
|
||||||
|
print(f"Save to {init_weight_name}...")
|
||||||
|
torch.save(mm, init_weight_name)
|
||||||
|
|
||||||
|
if model.args.my_pile_stage == 1:
|
||||||
|
print("Done. Now go for stage 2.")
|
||||||
|
exit(0)
|
139
finetune/lora/v6/src/utils.py
vendored
Normal file
139
finetune/lora/v6/src/utils.py
vendored
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
import json, time, random, os
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
time_slot = {}
|
||||||
|
time_ref = time.time_ns()
|
||||||
|
|
||||||
|
|
||||||
|
def record_time(name):
|
||||||
|
if name not in time_slot:
|
||||||
|
time_slot[name] = 1e20
|
||||||
|
tt = (time.time_ns() - time_ref) / 1e9
|
||||||
|
if tt < time_slot[name]:
|
||||||
|
time_slot[name] = tt
|
||||||
|
|
||||||
|
|
||||||
|
class TOKENIZER:
|
||||||
|
def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"):
|
||||||
|
if "list" in str(type(WORD_NAME)):
|
||||||
|
self.charMode = False
|
||||||
|
if WORD_NAME[0] == WORD_NAME[1]:
|
||||||
|
from transformers import PreTrainedTokenizerFast
|
||||||
|
|
||||||
|
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
|
||||||
|
else:
|
||||||
|
from transformers import GPT2TokenizerFast
|
||||||
|
|
||||||
|
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
|
||||||
|
self.vocab_size = len(self.tokenizer)
|
||||||
|
else:
|
||||||
|
self.charMode = True
|
||||||
|
with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file:
|
||||||
|
self.word_table = json.load(result_file)
|
||||||
|
|
||||||
|
self.vocab_size = len(self.word_table)
|
||||||
|
|
||||||
|
self.stoi = {v: int(k) for k, v in self.word_table.items()}
|
||||||
|
self.itos = {int(k): v for k, v in self.word_table.items()}
|
||||||
|
|
||||||
|
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
|
||||||
|
|
||||||
|
def refine_context(self, context):
|
||||||
|
context = context.strip().split("\n")
|
||||||
|
for c in range(len(context)):
|
||||||
|
context[c] = context[c].strip().strip("\u3000").strip("\r")
|
||||||
|
context = list(filter(lambda c: c != "", context))
|
||||||
|
context = "\n" + ("\n".join(context)).strip()
|
||||||
|
if context == "":
|
||||||
|
context = "\n"
|
||||||
|
return context
|
||||||
|
|
||||||
|
def sample_logits(
|
||||||
|
self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None
|
||||||
|
):
|
||||||
|
# out[self.UNKNOWN_CHAR] = -float('Inf')
|
||||||
|
lastChar = int(x[-1])
|
||||||
|
|
||||||
|
probs = F.softmax(out, dim=-1)
|
||||||
|
|
||||||
|
if self.charMode:
|
||||||
|
if self.itos[lastChar] == "\n":
|
||||||
|
top_p = top_p_newline
|
||||||
|
else:
|
||||||
|
top_p = top_p_usual
|
||||||
|
else:
|
||||||
|
top_p = top_p_usual
|
||||||
|
|
||||||
|
if os.environ["RWKV_RUN_DEVICE"] == "cpu":
|
||||||
|
probs = probs.numpy()
|
||||||
|
sorted_probs = np.sort(probs)[::-1]
|
||||||
|
cumulative_probs = np.cumsum(sorted_probs)
|
||||||
|
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||||
|
probs[probs < cutoff] = 0
|
||||||
|
if temperature != 1.0:
|
||||||
|
probs = probs.pow(1.0 / temperature)
|
||||||
|
probs = probs / np.sum(probs)
|
||||||
|
out = np.random.choice(a=len(probs), p=probs)
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
sorted_probs = torch.sort(probs, descending=True)[0]
|
||||||
|
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
|
||||||
|
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||||
|
probs[probs < cutoff] = 0
|
||||||
|
if temperature != 1.0:
|
||||||
|
probs = probs.pow(1.0 / temperature)
|
||||||
|
out = torch.multinomial(probs, num_samples=1)[0]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def MaybeIsPrime(number):
|
||||||
|
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def FermatPrimalityTest(number):
|
||||||
|
if number > 1:
|
||||||
|
for time in range(3):
|
||||||
|
randomNumber = random.randint(2, number) - 1
|
||||||
|
if pow(randomNumber, number - 1, number) != 1:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def MillerRabinPrimalityTest(number):
|
||||||
|
if number == 2:
|
||||||
|
return True
|
||||||
|
elif number == 1 or number % 2 == 0:
|
||||||
|
return False
|
||||||
|
oddPartOfNumber = number - 1
|
||||||
|
timesTwoDividNumber = 0
|
||||||
|
while oddPartOfNumber % 2 == 0:
|
||||||
|
oddPartOfNumber = oddPartOfNumber // 2
|
||||||
|
timesTwoDividNumber = timesTwoDividNumber + 1
|
||||||
|
|
||||||
|
for time in range(3):
|
||||||
|
while True:
|
||||||
|
randomNumber = random.randint(2, number) - 1
|
||||||
|
if randomNumber != 0 and randomNumber != 1:
|
||||||
|
break
|
||||||
|
|
||||||
|
randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
|
||||||
|
|
||||||
|
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
|
||||||
|
iterationNumber = 1
|
||||||
|
|
||||||
|
while (iterationNumber <= timesTwoDividNumber - 1) and (
|
||||||
|
randomNumberWithPower != number - 1
|
||||||
|
):
|
||||||
|
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
|
||||||
|
iterationNumber = iterationNumber + 1
|
||||||
|
if randomNumberWithPower != (number - 1):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
435
finetune/lora/v6/train.py
vendored
Normal file
435
finetune/lora/v6/train.py
vendored
Normal file
@ -0,0 +1,435 @@
|
|||||||
|
########################################################################################################
|
||||||
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from pytorch_lightning import Trainer
|
||||||
|
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
|
||||||
|
rank_zero_info("########## work in progress ##########")
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
|
||||||
|
parser.add_argument(
|
||||||
|
"--wandb", default="", type=str
|
||||||
|
) # wandb project name. if "" then don't use wandb
|
||||||
|
parser.add_argument("--proj_dir", default="out", type=str)
|
||||||
|
parser.add_argument("--random_seed", default="-1", type=int)
|
||||||
|
|
||||||
|
parser.add_argument("--data_file", default="", type=str)
|
||||||
|
parser.add_argument("--data_type", default="utf-8", type=str)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vocab_size", default=0, type=int
|
||||||
|
) # vocab_size = 0 means auto (for char-level LM and .txt data)
|
||||||
|
|
||||||
|
parser.add_argument("--ctx_len", default=1024, type=int)
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch_steps", default=1000, type=int
|
||||||
|
) # a mini "epoch" has [epoch_steps] steps
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch_count", default=500, type=int
|
||||||
|
) # train for this many "epochs". will continue afterwards with lr = lr_final
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch_begin", default=0, type=int
|
||||||
|
) # if you load a model trained for x "epochs", set epoch_begin = x
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch_save", default=5, type=int
|
||||||
|
) # save the model every [epoch_save] "epochs"
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--micro_bsz", default=12, type=int
|
||||||
|
) # micro batch size (batch size per GPU)
|
||||||
|
parser.add_argument("--n_layer", default=6, type=int)
|
||||||
|
parser.add_argument("--n_embd", default=512, type=int)
|
||||||
|
parser.add_argument("--dim_att", default=0, type=int)
|
||||||
|
parser.add_argument("--dim_ffn", default=0, type=int)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pre_ffn", default=0, type=int
|
||||||
|
) # replace first att layer by ffn (sometimes better)
|
||||||
|
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
|
||||||
|
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
|
||||||
|
parser.add_argument(
|
||||||
|
"--tiny_att_layer", default=-999, type=int
|
||||||
|
) # tiny attention @ which layer
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lr_init", default=6e-4, type=float
|
||||||
|
) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
|
||||||
|
parser.add_argument("--lr_final", default=1e-5, type=float)
|
||||||
|
parser.add_argument(
|
||||||
|
"--warmup_steps", default=-1, type=int
|
||||||
|
) # try 50 if you load a model
|
||||||
|
parser.add_argument("--beta1", default=0.9, type=float)
|
||||||
|
parser.add_argument(
|
||||||
|
"--beta2", default=0.99, type=float
|
||||||
|
) # use 0.999 when your model is close to convergence
|
||||||
|
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grad_cp", default=0, type=int
|
||||||
|
) # gradient checkpt: saves VRAM, but slower
|
||||||
|
parser.add_argument(
|
||||||
|
"--dropout", default=0, type=float
|
||||||
|
) # try 0.01 / 0.02 / 0.05 / 0.1
|
||||||
|
parser.add_argument(
|
||||||
|
"--weight_decay", default=0, type=float
|
||||||
|
) # try 0.1 / 0.01 / 0.001
|
||||||
|
parser.add_argument("--weight_decay_final", default=-1, type=float)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--my_pile_version", default=1, type=int
|
||||||
|
) # my special pile version
|
||||||
|
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
|
||||||
|
parser.add_argument(
|
||||||
|
"--my_pile_shift", default=-1, type=int
|
||||||
|
) # my special pile mode - text shift
|
||||||
|
parser.add_argument("--my_pile_edecay", default=0, type=int)
|
||||||
|
parser.add_argument(
|
||||||
|
"--layerwise_lr", default=1, type=int
|
||||||
|
) # layerwise lr for faster convergence (but slower it/s)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds_bucket_mb", default=200, type=int
|
||||||
|
) # deepspeed bucket size in MB. 200 seems enough
|
||||||
|
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
|
||||||
|
|
||||||
|
parser.add_argument("--my_sample_len", default=0, type=int)
|
||||||
|
parser.add_argument("--my_ffn_shift", default=1, type=int)
|
||||||
|
parser.add_argument("--my_att_shift", default=1, type=int)
|
||||||
|
parser.add_argument(
|
||||||
|
"--head_size_a", default=64, type=int
|
||||||
|
) # can try larger values for larger models
|
||||||
|
parser.add_argument("--head_size_divisor", default=8, type=int)
|
||||||
|
parser.add_argument("--my_pos_emb", default=0, type=int)
|
||||||
|
parser.add_argument("--load_partial", default=0, type=int)
|
||||||
|
parser.add_argument("--magic_prime", default=0, type=int)
|
||||||
|
parser.add_argument("--my_qa_mask", default=0, type=int)
|
||||||
|
parser.add_argument("--my_random_steps", default=0, type=int)
|
||||||
|
parser.add_argument("--my_testing", default="", type=str)
|
||||||
|
parser.add_argument("--my_exit", default=99999999, type=int)
|
||||||
|
parser.add_argument("--my_exit_tokens", default=0, type=int)
|
||||||
|
|
||||||
|
# LORA
|
||||||
|
parser.add_argument("--emb", action="store_true")
|
||||||
|
parser.add_argument("--lora", action="store_true")
|
||||||
|
parser.add_argument("--lora_load", default="", type=str)
|
||||||
|
parser.add_argument("--lora_r", default=8, type=int)
|
||||||
|
parser.add_argument("--lora_alpha", default=32, type=float)
|
||||||
|
parser.add_argument("--lora_dropout", default=0.01, type=float)
|
||||||
|
parser.add_argument("--lora_parts", default="att,ln,time", type=str)
|
||||||
|
|
||||||
|
if pl.__version__[0] == "2":
|
||||||
|
parser.add_argument("--accelerator", default="gpu", type=str)
|
||||||
|
parser.add_argument("--strategy", default="auto", type=str)
|
||||||
|
parser.add_argument("--devices", default=1, type=int)
|
||||||
|
parser.add_argument("--num_nodes", default=1, type=int)
|
||||||
|
parser.add_argument("--precision", default="fp16", type=str)
|
||||||
|
parser.add_argument("--accumulate_grad_batches", default=1, type=int)
|
||||||
|
else:
|
||||||
|
parser = Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
import os, warnings, math, datetime, sys, time
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
if "deepspeed" in args.strategy:
|
||||||
|
import deepspeed
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
|
||||||
|
if args.random_seed >= 0:
|
||||||
|
print(
|
||||||
|
f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n"
|
||||||
|
* 3
|
||||||
|
)
|
||||||
|
seed_everything(args.random_seed)
|
||||||
|
|
||||||
|
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore", ".*Consider increasing the value of the `num_workers` argument*"
|
||||||
|
)
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore", ".*The progress bar already tracks a metric with the*"
|
||||||
|
)
|
||||||
|
# os.environ["WDS_SHOW_SEED"] = "1"
|
||||||
|
|
||||||
|
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
|
||||||
|
args.enable_checkpointing = False
|
||||||
|
args.replace_sampler_ddp = False
|
||||||
|
args.logger = False
|
||||||
|
args.gradient_clip_val = 1.0
|
||||||
|
args.num_sanity_val_steps = 0
|
||||||
|
args.check_val_every_n_epoch = int(1e20)
|
||||||
|
args.log_every_n_steps = int(1e20)
|
||||||
|
args.max_epochs = args.epoch_count # -1 continue forever
|
||||||
|
args.betas = (args.beta1, args.beta2)
|
||||||
|
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
|
||||||
|
os.environ["RWKV_MY_TESTING"] = args.my_testing
|
||||||
|
os.environ["RWKV_CTXLEN"] = str(args.ctx_len)
|
||||||
|
os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)
|
||||||
|
if args.dim_att <= 0:
|
||||||
|
args.dim_att = args.n_embd
|
||||||
|
if args.dim_ffn <= 0:
|
||||||
|
args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size
|
||||||
|
|
||||||
|
if args.data_type == "wds_img":
|
||||||
|
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
|
||||||
|
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
|
||||||
|
else:
|
||||||
|
args.run_name = (
|
||||||
|
f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
|
||||||
|
)
|
||||||
|
if not os.path.exists(args.proj_dir):
|
||||||
|
os.makedirs(args.proj_dir)
|
||||||
|
|
||||||
|
if args.my_pile_stage > 0:
|
||||||
|
magic_prime_bak = args.magic_prime
|
||||||
|
|
||||||
|
if args.my_pile_shift < 0:
|
||||||
|
args.my_pile_shift = 0
|
||||||
|
|
||||||
|
if magic_prime_bak > 0:
|
||||||
|
args.magic_prime = magic_prime_bak
|
||||||
|
if args.my_qa_mask == 2:
|
||||||
|
args.epoch_count = 2 * args.magic_prime // 40320
|
||||||
|
else:
|
||||||
|
args.epoch_count = args.magic_prime // 40320
|
||||||
|
|
||||||
|
args.epoch_steps = 40320 // args.real_bsz
|
||||||
|
assert args.epoch_steps * args.real_bsz == 40320
|
||||||
|
# if args.my_pile_stage == 2:
|
||||||
|
# assert args.lr_final == args.lr_init
|
||||||
|
if args.my_pile_stage >= 2: # find latest saved model
|
||||||
|
list_p = []
|
||||||
|
for p in os.listdir(args.proj_dir):
|
||||||
|
if p.startswith("rwkv") and p.endswith(".pth"):
|
||||||
|
p = ((p.split("-"))[1].split("."))[0]
|
||||||
|
if p != "final":
|
||||||
|
if p == "init":
|
||||||
|
p = -1
|
||||||
|
else:
|
||||||
|
p = int(p)
|
||||||
|
list_p += [p]
|
||||||
|
list_p.sort()
|
||||||
|
max_p = list_p[-1]
|
||||||
|
if len(list_p) > 1:
|
||||||
|
args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted
|
||||||
|
if max_p == -1:
|
||||||
|
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
||||||
|
else:
|
||||||
|
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
||||||
|
if args.warmup_steps < 0:
|
||||||
|
if args.my_pile_stage == 2:
|
||||||
|
args.warmup_steps = 10
|
||||||
|
else:
|
||||||
|
args.warmup_steps = 30
|
||||||
|
args.epoch_begin = max_p + 1
|
||||||
|
|
||||||
|
samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||||
|
tokens_per_epoch = samples_per_epoch * args.ctx_len
|
||||||
|
try:
|
||||||
|
deepspeed_version = deepspeed.__version__
|
||||||
|
except:
|
||||||
|
deepspeed_version = None
|
||||||
|
pass
|
||||||
|
rank_zero_info(
|
||||||
|
f"""
|
||||||
|
############################################################################
|
||||||
|
#
|
||||||
|
# RWKV-5 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
|
||||||
|
#
|
||||||
|
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
|
||||||
|
#
|
||||||
|
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1}, save every {args.epoch_save} epoch
|
||||||
|
#
|
||||||
|
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
|
||||||
|
#
|
||||||
|
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
|
||||||
|
#
|
||||||
|
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
|
||||||
|
#
|
||||||
|
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
|
||||||
|
# Found deepspeed {deepspeed_version}, recommend 0.7.0 (faster than newer versions)
|
||||||
|
# Found pytorch_lightning {pl.__version__}, recommend 1.9.5
|
||||||
|
#
|
||||||
|
############################################################################
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
rank_zero_info(str(vars(args)) + "\n")
|
||||||
|
|
||||||
|
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "uint16"]
|
||||||
|
|
||||||
|
if args.lr_final == 0 or args.lr_init == 0:
|
||||||
|
rank_zero_info(
|
||||||
|
"\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
|
||||||
|
os.environ["RWKV_FLOAT_MODE"] = args.precision
|
||||||
|
if args.precision == "fp32":
|
||||||
|
for i in range(10):
|
||||||
|
rank_zero_info(
|
||||||
|
"\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n"
|
||||||
|
)
|
||||||
|
if args.precision == "fp16":
|
||||||
|
rank_zero_info(
|
||||||
|
"\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
os.environ["RWKV_JIT_ON"] = "0"
|
||||||
|
if "deepspeed_stage_3" in args.strategy:
|
||||||
|
os.environ["RWKV_JIT_ON"] = "0"
|
||||||
|
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
if args.precision == "fp32":
|
||||||
|
torch.backends.cudnn.allow_tf32 = False
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
else:
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
|
if "32" in args.precision:
|
||||||
|
args.precision = 32
|
||||||
|
elif args.precision == "fp16":
|
||||||
|
args.precision = 16
|
||||||
|
else:
|
||||||
|
args.precision = "bf16"
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
from src.trainer import train_callback, generate_init_weight
|
||||||
|
from src.dataset import MyDataset
|
||||||
|
|
||||||
|
train_data = MyDataset(args)
|
||||||
|
args.vocab_size = train_data.vocab_size
|
||||||
|
|
||||||
|
from src.model import RWKV, LORA_CONFIG, LoraLinear
|
||||||
|
|
||||||
|
model = RWKV(args)
|
||||||
|
|
||||||
|
if args.lora:
|
||||||
|
assert args.lora_r > 0, "LoRA should have its `r` > 0"
|
||||||
|
LORA_CONFIG["r"] = args.lora_r
|
||||||
|
LORA_CONFIG["alpha"] = args.lora_alpha
|
||||||
|
LORA_CONFIG["dropout"] = args.lora_dropout
|
||||||
|
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
|
||||||
|
enable_time_finetune = "time" in LORA_CONFIG["parts"]
|
||||||
|
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
|
||||||
|
model.requires_grad_(False)
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
|
||||||
|
if any(n.startswith("lora_") for n, _ in module.named_parameters()):
|
||||||
|
print(f" LoRA additionally training module {name}")
|
||||||
|
for pname, param in module.named_parameters():
|
||||||
|
param.requires_grad = "lora_" in pname
|
||||||
|
elif enable_ln_finetune and ".ln" in name:
|
||||||
|
print(f" LoRA additionally training module {name}")
|
||||||
|
for param in module.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
elif enable_time_finetune and any(
|
||||||
|
n.startswith("time") for n, _ in module.named_parameters()
|
||||||
|
):
|
||||||
|
for pname, param in module.named_parameters():
|
||||||
|
if pname.startswith("time"):
|
||||||
|
print(f" LoRA additionally training parameter {pname}")
|
||||||
|
param.requires_grad = True
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(args.load_model) == 0 or args.my_pile_stage == 1
|
||||||
|
): # shall we build the initial weights?
|
||||||
|
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
|
||||||
|
generate_init_weight(model, init_weight_name) # save initial weights
|
||||||
|
args.load_model = init_weight_name
|
||||||
|
|
||||||
|
rank_zero_info(f"########## Loading {args.load_model}... ##########")
|
||||||
|
try:
|
||||||
|
load_dict = torch.load(args.load_model, map_location="cpu")
|
||||||
|
load_keys = list(load_dict.keys())
|
||||||
|
for k in load_keys:
|
||||||
|
if k.startswith("_forward_module."):
|
||||||
|
load_dict[k.replace("_forward_module.", "")] = load_dict[k]
|
||||||
|
del load_dict[k]
|
||||||
|
except:
|
||||||
|
rank_zero_info(f"Bad checkpoint {args.load_model}")
|
||||||
|
if args.my_pile_stage >= 2: # try again using another checkpoint
|
||||||
|
max_p = args.my_pile_prev_p
|
||||||
|
if max_p == -1:
|
||||||
|
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
||||||
|
else:
|
||||||
|
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
||||||
|
args.epoch_begin = max_p + 1
|
||||||
|
rank_zero_info(f"Trying {args.load_model}")
|
||||||
|
load_dict = torch.load(args.load_model, map_location="cpu")
|
||||||
|
|
||||||
|
if args.load_partial == 1:
|
||||||
|
load_keys = load_dict.keys()
|
||||||
|
for k in model.state_dict():
|
||||||
|
if k not in load_keys:
|
||||||
|
load_dict[k] = model.state_dict()[k]
|
||||||
|
model.load_state_dict(load_dict, strict=(not args.lora))
|
||||||
|
if os.path.isfile(args.lora_load):
|
||||||
|
model.load_state_dict(
|
||||||
|
torch.load(args.lora_load, map_location="cpu"), strict=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if pl.__version__[0] == "2":
|
||||||
|
trainer = Trainer(
|
||||||
|
accelerator=args.accelerator,
|
||||||
|
strategy=args.strategy,
|
||||||
|
devices=args.devices,
|
||||||
|
num_nodes=args.num_nodes,
|
||||||
|
precision=args.precision,
|
||||||
|
logger=args.logger,
|
||||||
|
callbacks=[train_callback(args)],
|
||||||
|
max_epochs=args.max_epochs,
|
||||||
|
check_val_every_n_epoch=args.check_val_every_n_epoch,
|
||||||
|
num_sanity_val_steps=args.num_sanity_val_steps,
|
||||||
|
log_every_n_steps=args.log_every_n_steps,
|
||||||
|
enable_checkpointing=args.enable_checkpointing,
|
||||||
|
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||||
|
gradient_clip_val=args.gradient_clip_val,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
trainer = Trainer.from_argparse_args(
|
||||||
|
args,
|
||||||
|
callbacks=[train_callback(args)],
|
||||||
|
)
|
||||||
|
|
||||||
|
if trainer.global_rank == 0:
|
||||||
|
for n in model.state_dict():
|
||||||
|
shape = model.state_dict()[n].shape
|
||||||
|
shape = [i for i in shape if i != 1]
|
||||||
|
if len(shape) > 1:
|
||||||
|
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
|
||||||
|
else:
|
||||||
|
print(f"{str(shape[0]).ljust(5)} {n}")
|
||||||
|
|
||||||
|
if "deepspeed" in args.strategy:
|
||||||
|
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
|
||||||
|
args.ds_bucket_mb * 1000 * 1000
|
||||||
|
)
|
||||||
|
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
|
||||||
|
args.ds_bucket_mb * 1000 * 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
||||||
|
data_loader = DataLoader(
|
||||||
|
train_data,
|
||||||
|
shuffle=False,
|
||||||
|
pin_memory=True,
|
||||||
|
batch_size=args.micro_bsz,
|
||||||
|
num_workers=1,
|
||||||
|
persistent_workers=False,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.fit(model, data_loader)
|
Loading…
Reference in New Issue
Block a user