This commit is contained in:
josc146 2024-05-28 22:35:47 +08:00
parent 3488d22d22
commit f05a4acb04
138 changed files with 29047 additions and 334 deletions

311
finetune/lora/v6/cuda/wkv6infctx_cuda.cu vendored Normal file
View File

@ -0,0 +1,311 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, F *__restrict__ _s,
F *__restrict__ const _y)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_u += h*_N_;
_s += h*_N_*_N_ + i*_N_;
__shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
float state[_N_];
__syncthreads();
u[i] = float(_u[i]);
__syncthreads();
for (int j = 0; j < _N_; j++) {
state[j] = float(_s[j]);
}
for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
{
__syncthreads();
w[i] = __expf(-__expf(float(_w[t])));
r[i] = float(_r[t]);
k[i] = float(_k[t]);
__syncthreads();
const float v = float(_v[t]);
float y = 0;
#pragma unroll
for (int j = 0; j < _N_; j+=4)
{
const float4& r_ = (float4&)(r[j]);
const float4& k_ = (float4&)(k[j]);
const float4& w_ = (float4&)(w[j]);
const float4& u_ = (float4&)(u[j]);
float4& s = (float4&)(state[j]);
float4 x;
x.x = k_.x * v;
x.y = k_.y * v;
x.z = k_.z * v;
x.w = k_.w * v;
y += r_.x * (u_.x * x.x + s.x);
y += r_.y * (u_.y * x.y + s.y);
y += r_.z * (u_.z * x.z + s.z);
y += r_.w * (u_.w * x.w + s.w);
s.x = s.x * w_.x + x.x;
s.y = s.y * w_.y + x.y;
s.z = s.z * w_.z + x.z;
s.w = s.w * w_.w + x.w;
}
_y[t] = F(y);
}
#pragma unroll
for (int j = 0; j < _N_; j++)
_s[j] = F(state[j]);
}
template <typename F>
__global__ void kernel_backward_111(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy,
F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu, F *__restrict__ const _gs)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_u += h*_N_;
_s += h*_N_*_N_ + i;
__shared__ float u_[_N_];
__shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_];
__syncthreads();
u_[i] = float(_u[i]);
__syncthreads();
const float u = u_[i];
float state[_N_], scccc[_N_] = {0}, sdddd[_N_] = {0}, sssss[_N_] = {0}, swwww[_N_];
for (int j = 0; j < _N_; j++) {
state[j] = float(_s[j*_N_]);
swwww[j] = 1.0;
}
const int t_0 = b*T*C + h*_N_ + i;
const int t_T_1 = t_0 + (T-1)*C;
const int t_T = t_0 + T*C;
float gu = 0;
for (int t = t_0; t < t_T; t += C)
{
__syncthreads();
v[i] = float(_v[t]);
gy[i] = float(_gy[t]);
__syncthreads();
const float k = float(_k[t]);
const float w = __expf(-__expf(float(_w[t])));
float gr = 0, gu_ = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = state[j];
float x = k * v[j];
gr += (u * x + s) * gy[j];
gu_ += x * gy[j];
s = s * w + x;
}
_gr[t] = F(gr);
gu += float(_r[t]) * gu_;
}
_gu[b*C + h*_N_ + i] = F(gu);
for (int t = t_T_1; t >= t_0; t -= C)
{
__syncthreads();
v[i] = float(_v[t]);
gy[i] = float(_gy[t]);
__syncthreads();
const float rr = float(_r[t]);
const float w = __expf(-__expf(float(_w[t])));
float gk = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = scccc[j];
float x = rr * gy[j];
gk += (u * x + s) * v[j];
s = x + s * w;
}
_gk[t] = F(gk);
}
for (int t = t_T_1; t >= t_0; t -= C)
{
__syncthreads();
r[i] = float(_r[t]);
k[i] = float(_k[t]);
w_[i] = __expf(-__expf(float(_w[t])));
__syncthreads();
const float gyy = float(_gy[t]);
float gv = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = sdddd[j];
float x = gyy * r[j];
gv += (u_[j] * x + s) * k[j];
s = x + s * w_[j];
}
_gv[t] = F(gv);
}
for (int t = t_0; t < t_T; t += C)
{
__syncthreads();
r[i] = float(_r[t]);
w_[i] = __expf(-__expf(float(_w[t])));
__syncthreads();
const float gyy = float(_gy[t]);
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& w = swwww[j];
sssss[j] += gyy * w * r[j];
w *= w_[j];
}
}
for (int j = 0; j < _N_; j++)
_gs[b*H*_N_*_N_ + h*_N_*_N_ + i*_N_ + j] = F(sssss[j]);
}
template <typename F>
__global__ void kernel_backward_222(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy,
F *__restrict__ const _gw)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_s += h*_N_*_N_ + i;
__shared__ float v[_N_], gy[_N_];
float state[_N_], saaaa[_N_] = {0}, sbbbb[_T_-1] = {0}, scccc[_N_] = {0};
for (int j = 0; j < _N_; j++) {
state[j] = float(_s[j*_N_]);
}
const int t_0 = b*T*C + h*_N_ + i;
const int t_1 = t_0 + C;
const int t_2 = t_0 + 2*C;
const int t_T_1 = t_0 + (T-1)*C;
for (int t = t_T_1; t > t_1; t -= C)
{
__syncthreads();
gy[i] = float(_gy[t]);
v[i] = float(_v[t-2*C]);
__syncthreads();
const float r = float(_r[t]);
const float w = __expf(-__expf(float(_w[t-C])));
float sum = 0.0f;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = saaaa[j];
s = (s + r * gy[j]) * w;
sum += s * v[j];
}
sbbbb[(t-t_1)/C] = sum * float(_k[t-2*C]);
}
{
__syncthreads();
gy[i] = float(_gy[t_1]);
__syncthreads();
const float r = float(_r[t_1]);
const float w = __expf(-__expf(float(_w[t_0])));
float sum = 0.0f;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = saaaa[j];
s = (s + r * gy[j]) * w;
sum += s * state[j];
}
sbbbb[0] = sum;
}
float sss = sbbbb[0];
_gw[t_0] = F(sss * -__expf(float(_w[t_0])));
{
__syncthreads();
gy[i] = float(_gy[t_1]);
__syncthreads();
const float w = __expf(-__expf(float(_w[t_0])));
float sum = 0.0f;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = scccc[j];
s = (s + state[j]) * w;
sum += s * gy[j];
}
sss += sbbbb[1] - (sum * float(_r[t_1]));
_gw[t_1] = F(sss * -__expf(float(_w[t_1])));
}
for (int t = t_2; t < t_T_1; t += C)
{
__syncthreads();
gy[i] = float(_gy[t]);
v[i] = float(_v[t-2*C]);
__syncthreads();
const float w = __expf(-__expf(float(_w[t-C])));
const float k = float(_k[t-2*C]);
float sum = 0.0f;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = scccc[j];
s = (s + k * v[j]) * w;
sum += s * gy[j];
}
sss += sbbbb[(t-t_0)/C] - (sum * float(_r[t]));
_gw[t] = F(sss * -__expf(float(_w[t])));
}
_gw[t_T_1] = 0;
}
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *y)
{
assert(H*_N_ == C);
assert(_N_%4 == 0);
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, y);
}
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs)
{
assert(H*_N_ == C);
assert(_N_%4 == 0);
kernel_backward_111<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gr, gk, gv, gu, gs);
kernel_backward_222<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gw);
}

22
finetune/lora/v6/cuda/wkv6infctx_op.cpp vendored Normal file
View File

@ -0,0 +1,22 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y);
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs);
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) {
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) {
cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gs.data_ptr<bf16>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv6state forward");
m.def("backward", &backward, "wkv6state backward");
}
TORCH_LIBRARY(wkv6state, m) {
m.def("forward", forward);
m.def("backward", backward);
}

311
finetune/lora/v6/cuda/wkv6state_cuda.cu vendored Normal file
View File

@ -0,0 +1,311 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u,const F *__restrict__ _s,
F *__restrict__ const _y)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_u += h*_N_;
_s += h*_N_*_N_ + i*_N_;
__shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
float state[_N_];
__syncthreads();
u[i] = float(_u[i]);
__syncthreads();
for (int j = 0; j < _N_; j++) {
state[j] = float(_s[j]);
}
for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
{
__syncthreads();
w[i] = __expf(-__expf(float(_w[t])));
r[i] = float(_r[t]);
k[i] = float(_k[t]);
__syncthreads();
const float v = float(_v[t]);
float y = 0;
#pragma unroll
for (int j = 0; j < _N_; j+=4)
{
const float4& r_ = (float4&)(r[j]);
const float4& k_ = (float4&)(k[j]);
const float4& w_ = (float4&)(w[j]);
const float4& u_ = (float4&)(u[j]);
float4& s = (float4&)(state[j]);
float4 x;
x.x = k_.x * v;
x.y = k_.y * v;
x.z = k_.z * v;
x.w = k_.w * v;
y += r_.x * (u_.x * x.x + s.x);
y += r_.y * (u_.y * x.y + s.y);
y += r_.z * (u_.z * x.z + s.z);
y += r_.w * (u_.w * x.w + s.w);
s.x = s.x * w_.x + x.x;
s.y = s.y * w_.y + x.y;
s.z = s.z * w_.z + x.z;
s.w = s.w * w_.w + x.w;
}
_y[t] = F(y);
}
// #pragma unroll
// for (int j = 0; j < _N_; j++)
// _s[j] = F(state[j]);
}
template <typename F>
__global__ void kernel_backward_111(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy,
F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu, F *__restrict__ const _gs)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_u += h*_N_;
_s += h*_N_*_N_ + i;
__shared__ float u_[_N_];
__shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_];
__syncthreads();
u_[i] = float(_u[i]);
__syncthreads();
const float u = u_[i];
float state[_N_], scccc[_N_] = {0}, sdddd[_N_] = {0}, sssss[_N_] = {0}, swwww[_N_];
for (int j = 0; j < _N_; j++) {
state[j] = float(_s[j*_N_]);
swwww[j] = 1.0;
}
const int t_0 = b*T*C + h*_N_ + i;
const int t_T_1 = t_0 + (T-1)*C;
const int t_T = t_0 + T*C;
float gu = 0;
for (int t = t_0; t < t_T; t += C)
{
__syncthreads();
v[i] = float(_v[t]);
gy[i] = float(_gy[t]);
__syncthreads();
const float k = float(_k[t]);
const float w = __expf(-__expf(float(_w[t])));
float gr = 0, gu_ = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = state[j];
float x = k * v[j];
gr += (u * x + s) * gy[j];
gu_ += x * gy[j];
s = s * w + x;
}
_gr[t] = F(gr);
gu += float(_r[t]) * gu_;
}
_gu[b*C + h*_N_ + i] = F(gu);
for (int t = t_T_1; t >= t_0; t -= C)
{
__syncthreads();
v[i] = float(_v[t]);
gy[i] = float(_gy[t]);
__syncthreads();
const float rr = float(_r[t]);
const float w = __expf(-__expf(float(_w[t])));
float gk = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = scccc[j];
float x = rr * gy[j];
gk += (u * x + s) * v[j];
s = x + s * w;
}
_gk[t] = F(gk);
}
for (int t = t_T_1; t >= t_0; t -= C)
{
__syncthreads();
r[i] = float(_r[t]);
k[i] = float(_k[t]);
w_[i] = __expf(-__expf(float(_w[t])));
__syncthreads();
const float gyy = float(_gy[t]);
float gv = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = sdddd[j];
float x = gyy * r[j];
gv += (u_[j] * x + s) * k[j];
s = x + s * w_[j];
}
_gv[t] = F(gv);
}
for (int t = t_0; t < t_T; t += C)
{
__syncthreads();
r[i] = float(_r[t]);
w_[i] = __expf(-__expf(float(_w[t])));
__syncthreads();
const float gyy = float(_gy[t]);
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& w = swwww[j];
sssss[j] += gyy * w * r[j];
w *= w_[j];
}
}
for (int j = 0; j < _N_; j++)
_gs[b*H*_N_*_N_ + h*_N_*_N_ + i*_N_ + j] = F(sssss[j]);
}
template <typename F>
__global__ void kernel_backward_222(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy,
F *__restrict__ const _gw)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_s += h*_N_*_N_ + i;
__shared__ float v[_N_], gy[_N_];
float state[_N_], saaaa[_N_] = {0}, sbbbb[_T_-1] = {0}, scccc[_N_] = {0};
for (int j = 0; j < _N_; j++) {
state[j] = float(_s[j*_N_]);
}
const int t_0 = b*T*C + h*_N_ + i;
const int t_1 = t_0 + C;
const int t_2 = t_0 + 2*C;
const int t_T_1 = t_0 + (T-1)*C;
for (int t = t_T_1; t > t_1; t -= C)
{
__syncthreads();
gy[i] = float(_gy[t]);
v[i] = float(_v[t-2*C]);
__syncthreads();
const float r = float(_r[t]);
const float w = __expf(-__expf(float(_w[t-C])));
float sum = 0.0f;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = saaaa[j];
s = (s + r * gy[j]) * w;
sum += s * v[j];
}
sbbbb[(t-t_1)/C] = sum * float(_k[t-2*C]);
}
{
__syncthreads();
gy[i] = float(_gy[t_1]);
__syncthreads();
const float r = float(_r[t_1]);
const float w = __expf(-__expf(float(_w[t_0])));
float sum = 0.0f;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = saaaa[j];
s = (s + r * gy[j]) * w;
sum += s * state[j];
}
sbbbb[0] = sum;
}
float sss = sbbbb[0];
_gw[t_0] = F(sss * -__expf(float(_w[t_0])));
{
__syncthreads();
gy[i] = float(_gy[t_1]);
__syncthreads();
const float w = __expf(-__expf(float(_w[t_0])));
float sum = 0.0f;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = scccc[j];
s = (s + state[j]) * w;
sum += s * gy[j];
}
sss += sbbbb[1] - (sum * float(_r[t_1]));
_gw[t_1] = F(sss * -__expf(float(_w[t_1])));
}
for (int t = t_2; t < t_T_1; t += C)
{
__syncthreads();
gy[i] = float(_gy[t]);
v[i] = float(_v[t-2*C]);
__syncthreads();
const float w = __expf(-__expf(float(_w[t-C])));
const float k = float(_k[t-2*C]);
float sum = 0.0f;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = scccc[j];
s = (s + k * v[j]) * w;
sum += s * gy[j];
}
sss += sbbbb[(t-t_0)/C] - (sum * float(_r[t]));
_gw[t] = F(sss * -__expf(float(_w[t])));
}
_gw[t_T_1] = 0;
}
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *y)
{
assert(H*_N_ == C);
assert(_N_%4 == 0);
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, y);
}
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs)
{
assert(H*_N_ == C);
assert(_N_%4 == 0);
kernel_backward_111<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gr, gk, gv, gu, gs);
kernel_backward_222<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gw);
}

22
finetune/lora/v6/cuda/wkv6state_op.cpp vendored Normal file
View File

@ -0,0 +1,22 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y);
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs);
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) {
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) {
cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gs.data_ptr<bf16>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv6state forward");
m.def("backward", &backward, "wkv6state backward");
}
TORCH_LIBRARY(wkv6state, m) {
m.def("forward", forward);
m.def("backward", backward);
}

View File

@ -0,0 +1,16 @@
base_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth'
lora_init='/home/rwkv/JL/out_model/nf4/init_lora.pth'
lora_checkpoint='/home/rwkv/JL/out_model/nf4/rwkv-0.pth'
output='/home/rwkv/JL/model/nf4-world.pth'
QUANT='nf4' #follow train
TYPE='lora'
Lora_alpha=128
python merge/merge.py --base_model $base_model \
--lora_init $lora_init \
--lora_checkpoint $lora_checkpoint \
--output $output \
--quant $QUANT \
--type $TYPE \
--lora_alpha $Lora_alpha

27
finetune/lora/v6/demo/demo-lora.sh vendored Normal file
View File

@ -0,0 +1,27 @@
load_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth'
proj_dir='/home/rwkv/JL/out_model/nf4'
data_file='/home/rwkv/JL/data/roleplay'
QUANT='nf4' #4bit nf4 fp4 none
lora_r=64
lora_alpha=128
n_layer=32
n_embd=4096
micro_bsz=8
epoch_save=1
epoch_steps=1000
ctx_len=1024
python train.py --load_model $load_model \
--proj_dir $proj_dir --data_file $data_file \
--data_type binidx --vocab_size 65536 \
--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
--n_layer $n_layer --n_embd $n_embd \
--pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
--my_testing "x060" \
--lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha $lora_alpha --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
--quant $QUANT

View File

@ -0,0 +1,15 @@
base_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2-20240208-ctx4096.pth'
lora_init='/home/rwkv/JL/out_model/nf4/init_lora.pth'
lora_checkpoint='/home/rwkv/JL/out_model/nf4/rwkv-0.pth'
output='/home/rwkv/JL/model/end-world.pth'
QUANT='nf4' #follow train
TYPE='pissa'
python merge/merge.py --base_model $base_model \
--lora_init $lora_init \
--lora_checkpoint $lora_checkpoint \
--output $output \
--quant $QUANT \
--type $TYPE

40
finetune/lora/v6/demo/demo-pissa.sh vendored Normal file
View File

@ -0,0 +1,40 @@
load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth'
proj_dir='/home/rwkv/JL/out_model/nf4'
data_file='/home/rwkv/JL/data/end_text_document'
QUANT='nf4' #4bit nf4 fp4 none
svd_niter=4
lora_r=64
n_layer=24
n_embd=2048
micro_bsz=8
epoch_save=1
epoch_steps=1000
ctx_len=1024
python train.py --load_model $load_model \
--proj_dir $proj_dir --data_file $data_file \
--data_type binidx --vocab_size 65536 \
--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
--n_layer $n_layer --n_embd $n_embd \
--pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
--my_testing "x060" \
--lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
--PISSA --svd_niter $svd_niter \
--dataload pad
###remove load_model
# python train.py --proj_dir $proj_dir --data_file $data_file \
# --data_type binidx --vocab_size 65536 \
# --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
# --n_layer $n_layer --n_embd $n_embd \
# --pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
# --my_testing "x060" \
# --lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
# --PISSA --svd_niter $svd_niter \
# --quant $QUANT

27
finetune/lora/v6/demo/demo-qpissa-pt.sh vendored Normal file
View File

@ -0,0 +1,27 @@
load_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth'
proj_dir='/home/rwkv/JL/out_model/nf4'
data_file='/home/rwkv/JL/data/roleplay'
QUANT='nf4' #4bit nf4 fp4 none
svd_niter=4
lora_r=64
n_layer=32
n_embd=4096
micro_bsz=4
epoch_save=1
epoch_steps=1000
ctx_len=1024
python train.py --proj_dir $proj_dir --data_file $data_file \
--data_type binidx --vocab_size 65536 \
--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
--n_layer $n_layer --n_embd $n_embd \
--pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
--my_testing "x060" \
--lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
--PISSA --svd_niter $svd_niter \
--quant $QUANT

View File

@ -0,0 +1,8 @@
base_model='/home/rwkv/JL/model/RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth'
state_checkpoint='/home/rwkv/JL/out_model/state/rwkv-9.pth'
output='/home/rwkv/JL/model/state-0.pth'
python merge/merge_state.py --base_model $base_model \
--state_checkpoint $state_checkpoint \
--output $output

View File

@ -0,0 +1,22 @@
load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth'
proj_dir='/home/rwkv/JL/out_model/state'
data_file='/home/rwkv/JL/data/end_text_document'
n_layer=24
n_embd=2048
micro_bsz=1
epoch_save=1
epoch_steps=1000
ctx_len=1024
python train.py --load_model $load_model \
--proj_dir $proj_dir --data_file $data_file \
--data_type binidx --vocab_size 65536 \
--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
--n_layer $n_layer --n_embd $n_embd \
--pre_ffn 0 --head_qk 0 --lr_init 1 --lr_final 1e-1 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 0 \
--my_testing "x060" \
--train_type "state" --dataload pad --wandb fla --fla

View File

@ -0,0 +1,27 @@
#!/bin/bash
# Create data directory
mkdir -p data
# Download minipile (1498226207 tokens, around 3GB)
wget --continue -O data/minipile.idx https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.idx
wget --continue -O data/minipile.bin https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.bin
# Generate initial model (L12-D768 = 169M)
BASE_NAME="model/0.1-1"
N_LAYER="12"
N_EMBD="768"
# magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case)
# use https://www.dcode.fr/prime-numbers-search
python train.py --wandb "" --proj_dir $BASE_NAME \
--data_file "data/minipile" --data_type "binidx" --vocab_size 65536 \
--ctx_len 512 --my_pile_stage 1 --epoch_count 1 --epoch_begin 0 \
--epoch_save 1 --weight_decay 0 --head_size_a 64 \
--num_nodes 1 --micro_bsz 1 --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 --my_exit_tokens 1498226207 --magic_prime 2926181 \
--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 \
--accelerator cpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 --enable_progress_bar False --ds_bucket_mb 200

View File

@ -0,0 +1,21 @@
#!/bin/bash
BASE_NAME="model/0.1-1"
N_LAYER="12"
N_EMBD="768"
M_BSZ="16" # takes 16G VRAM (reduce this to save VRAM)
LR_INIT="6e-4"
LR_FINAL="6e-5"
GRAD_CP=0 # set to 1 to save VRAM (will be slower)
EPOCH_SAVE=10
# magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case)
# use https://www.dcode.fr/prime-numbers-search
python train.py --load_model "0" --wandb "RWKV-5-Test" --proj_dir $BASE_NAME \
--ctx_len 512 --my_pile_stage 3 --epoch_count 999999 --epoch_begin 0 \
--data_file "data/minipile" --my_exit_tokens 1498226207 --magic_prime 2926181 \
--num_nodes 1 --micro_bsz $M_BSZ --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 \
--lr_init $LR_INIT --lr_final $LR_FINAL --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 --data_type "binidx" --vocab_size 65536 \
--weight_decay 0.001 --epoch_save $EPOCH_SAVE --head_size_a 64 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp $GRAD_CP --enable_progress_bar True --ds_bucket_mb 200

182
finetune/lora/v6/demo/demo.jsonl vendored Normal file

File diff suppressed because one or more lines are too long

25
finetune/lora/v6/demo/infctx.sh vendored Normal file
View File

@ -0,0 +1,25 @@
load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth'
proj_dir='/home/rwkv/JL/out_model/infctx'
data_file='/home/rwkv/JL/data/roleplay'
n_layer=24
n_embd=2048
micro_bsz=8
epoch_save=5
epoch_steps=1000
ctx_len=16384
chunk_ctx=2048
python train.py --load_model $load_model \
--proj_dir $proj_dir --data_file $data_file \
--data_type binidx --vocab_size 65536 \
--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
--n_layer $n_layer --n_embd $n_embd \
--pre_ffn 0 --head_qk 0 --lr_init 1e-4 --lr_final 1e-4 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
--lora_load rwkv-0 --lora --lora_r 64 --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
--my_testing "x060" --dataload pad \
--train_type infctx --chunk_ctx $chunk_ctx --fla --wandb infctx

50
finetune/lora/v6/fla/__init__.py vendored Normal file
View File

@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
from fla.layers import (ABCAttention, BasedLinearAttention, DeltaNet,
GatedLinearAttention, HGRN2Attention, LinearAttention,
MultiScaleRetention, ReBasedLinearAttention)
from fla.models import (ABCForCausalLM, ABCModel, DeltaNetForCausalLM,
DeltaNetModel, GLAForCausalLM, GLAModel,
HGRN2ForCausalLM, HGRN2Model, HGRNForCausalLM,
HGRNModel, LinearAttentionForCausalLM,
LinearAttentionModel, RetNetForCausalLM, RetNetModel,
RWKV6ForCausalLM, RWKV6Model, TransformerForCausalLM,
TransformerModel)
from fla.ops import (chunk_gla, chunk_retention, fused_chunk_based,
fused_chunk_gla, fused_chunk_retention)
__all__ = [
'ABCAttention',
'BasedLinearAttention',
'DeltaNet',
'HGRN2Attention',
'GatedLinearAttention',
'LinearAttention',
'MultiScaleRetention',
'ReBasedLinearAttention',
'ABCForCausalLM',
'ABCModel',
'DeltaNetForCausalLM',
'DeltaNetModel',
'HGRNForCausalLM',
'HGRNModel',
'HGRN2ForCausalLM',
'HGRN2Model',
'GLAForCausalLM',
'GLAModel',
'LinearAttentionForCausalLM',
'LinearAttentionModel',
'RetNetForCausalLM',
'RetNetModel',
'RWKV6ForCausalLM',
'RWKV6Model',
'TransformerForCausalLM',
'TransformerModel',
'chunk_gla',
'chunk_retention',
'fused_chunk_based',
'fused_chunk_gla',
'fused_chunk_retention'
]
__version__ = '0.1'

25
finetune/lora/v6/fla/layers/__init__.py vendored Normal file
View File

@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
from .abc import ABCAttention
from .based import BasedLinearAttention
from .delta_net import DeltaNet
from .gla import GatedLinearAttention
from .hgrn import HGRNAttention
from .hgrn2 import HGRN2Attention
from .linear_attn import LinearAttention
from .multiscale_retention import MultiScaleRetention
from .rebased import ReBasedLinearAttention
from .rwkv6 import RWKV6Attention
__all__ = [
'ABCAttention',
'BasedLinearAttention',
'DeltaNet',
'GatedLinearAttention',
'HGRNAttention',
'HGRN2Attention',
'LinearAttention',
'MultiScaleRetention',
'ReBasedLinearAttention',
'RWKV6Attention'
]

195
finetune/lora/v6/fla/layers/abc.py vendored Normal file
View File

@ -0,0 +1,195 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import warnings
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from transformers.cache_utils import Cache
from fla.modules import (FusedRMSNormSwishGate, RMSNorm, RotaryEmbedding,
ShortConvolution)
from fla.modules.activations import swiglu, swish
from fla.modules.convolution import proj_then_conv1d
from fla.ops.abc.chunk import chunk_abc
class ABCAttention(nn.Module):
def __init__(
self,
hidden_size: int = 1024,
expand_k: float = 0.5,
expand_v: float = 1.0,
num_heads: int = 4,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
num_slots: Optional[int] = None,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
gate_low_rank_dim: int = 16,
gate_logit_normalizer: int = 16,
use_input_gate: bool = False,
use_output_gate: bool = True,
use_norm: bool = True,
clamp_min: Optional[float] = -32,
clamp_max: Optional[float] = 32,
layer_idx: Optional[int] = None,
**kwargs
) -> ABCAttention:
super().__init__()
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.key_dim = int(self.hidden_size * self.expand_k)
self.value_dim = int(self.hidden_size * self.expand_v)
self.head_k_dim = self.key_dim // self.num_heads
self.head_v_dim = self.value_dim // self.num_heads
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.gate_low_rank_dim = gate_low_rank_dim
self.gate_logit_normalizer = gate_logit_normalizer
self.use_input_gate = use_input_gate
self.use_output_gate = use_output_gate
self.use_norm = use_norm
if num_slots is None:
num_slots = self.head_k_dim
self.num_slots = num_slots
self.norm_eps = norm_eps
self.clamp_min = clamp_min
self.clamp_max = clamp_max
self.layer_idx = layer_idx
if layer_idx is None:
warnings.warn(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
if use_output_gate:
self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
if self.use_norm:
if self.use_output_gate:
self.g_norm = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
else:
self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
if self.use_rope:
self.rotary = RotaryEmbedding(self.head_k_dim)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if self.use_short_conv:
if self.share_conv_kernel:
hidden_states = self.h_conv1d(hidden_states)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
else:
q = proj_then_conv1d(hidden_states, self.q_proj.weight, self.q_conv1d.weight, self.q_conv1d.bias)
k = proj_then_conv1d(hidden_states, self.k_proj.weight, self.k_conv1d.weight, self.k_conv1d.bias)
v = proj_then_conv1d(hidden_states, self.v_proj.weight, self.v_conv1d.weight, self.v_conv1d.bias)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
if self.use_input_gate:
q, k, v = map(lambda x: swish(x), (q, k, v))
if self.use_rope:
q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
k = rearrange(k, '... (h d) -> ... h d', h=self.num_heads)
seqlen_offset = 0
if past_key_values is not None:
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
q, k = self.rotary(q, k, seqlen_offset)
q = rearrange(q, 'b n h d -> b h n d', h=self.num_heads)
k = rearrange(k, 'b n h d -> b h n d', h=self.num_heads)
else:
q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)
# [batch_size, n_heads, seq_len, num_slots]
s = rearrange(self.s_proj(hidden_states), 'b t (h m) -> b h t m', h=self.num_heads)
s = s.clamp_(self.clamp_min, self.clamp_max)
last_state = past_key_values[self.layer_idx] if use_cache else None
o, last_state = chunk_abc(q, k, v, s, initial_state=last_state, output_final_state=use_cache)
if past_key_values is not None and last_state is not None:
past_key_values.update(last_state, self.layer_idx, q.shape[2])
o = rearrange(o, 'b h t d -> b t h d')
if self.use_norm and not self.use_output_gate:
o = self.g_norm(o)
elif self.use_output_gate:
g = rearrange(self.g_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_heads)
o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
o = rearrange(o, 'b t h d -> b t (h d)')
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
state += (param.new_zeros(batch_size, self.num_heads, self.head_k_dim, self.num_slots),
param.new_zeros(batch_size, self.num_heads, self.num_slots, self.head_v_dim))
return state
def state_size(self, sequence_length: int = 2048):
return self.num_heads * self.key_dim * self.head_v_dim

126
finetune/lora/v6/fla/layers/based.py vendored Normal file
View File

@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
"""
Linear attention in Based.
https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
"""
import torch
import torch.nn as nn
from einops import rearrange
from fla.modules.feature_map import TaylorFeatureMap
from fla.ops.based import parallel_based
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
class BasedLinearAttention(nn.Module):
def __init__(
self,
hidden_size: int,
l_max: int = 2048,
feature_dim: int = 16,
num_key_value_heads: int = 12,
num_heads: int = 12,
feature_name: str = "taylor_exp",
eps: float = 1e-12,
causal: bool = True,
mode: str = "parallel",
):
super().__init__()
self.hidden_size
self.l_max = l_max
self.mode = mode
assert self.mode in ["fused_chunk", "parallel", 'chunk']
# linear attention
self.feature_name = feature_name
self.feature_dim = feature_dim
self.num_key_value_heads = num_key_value_heads
self.num_heads = num_heads
self.head_dim = self.hidden_size // self.num_key_value_heads
self.causal = causal
self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.dropout = nn.Identity()
self.feature_map = TaylorFeatureMap(feature_dim)
self.eps = eps
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(self, hidden_states: torch.Tensor, **kwargs):
mode = self.mode
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v])
if mode == "fused_chunk":
q, k = self.feature_map(q), self.feature_map(k)
o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)
elif mode == 'chunk':
q, k = self.feature_map(q), self.feature_map(k)
o = chunk_linear_attn(q, k, v, normalize=True, scale=1)
elif mode == 'parallel':
assert q.shape[-1] <= 128
o = parallel_based(q, k, v, True, True)
o = rearrange(o, "b h l d -> b l (h d)")
o = self.o_proj(o)
o = self.dropout(o)
return o
# https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
"""
x (torch.Tensor): tensor of shape (b, d, l)
y (torch.Tensor): tensor of shape (b, d, l)
"""
# hidden_states = hidden_states.transpose(1, 2)
b, l, _ = hidden_states.size()
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2)
k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Linear attention
q, k = self.feature_map(q), self.feature_map(k)
q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
# Compute attention
if self.causal:
y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
else:
y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
y = rearrange(y, 'b h l d -> b l (h d)')
y = self.o_proj(y.to(hidden_states.dtype))
y = self.dropout(y)
return y.to(hidden_states.dtype)
if __name__ == '__main__':
batch = 4
seq_len = 1024
hidden_size = 1024
dtype = torch.float32
x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True)
dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda()
model = BasedLinearAttention(hidden_size, mode='chunk').to(dtype).cuda()
y = model(x)
y.backward(dy, retain_graph=True)
x_grad, x.grad = x.grad, None
y2 = model.forward_reference(x)
y2.backward(dy)
assert y.allclose(y2, 0, 1e-4), breakpoint()
assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint()
print("Pass")

254
finetune/lora/v6/fla/layers/delta_net.py vendored Normal file
View File

@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
# Sect4.2 of Linear Transformers Are Secretly Fast Weight Programmers https://arxiv.org/abs/2102.11174
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from transformers.cache_utils import Cache
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution, LayerNorm
from fla.modules.rotary import RotaryEmbedding
from fla.ops.delta_rule import (fused_chunk_delta_rule,
fused_recurrent_linear_attn_delta_rule,
chunk_delta_rule)
from torch.nn import functional as F
def simple_norm(x):
return (F.normalize(x, dim=-1) * x.shape[-1] ** 0.5).to(x)
# @torch.jit.script
def elu_p1(x):
return (F.elu(x, 1., False) + 1.).to(x)
# @torch.jit.script
def sum_norm(x):
return (x / x.sum(-1, keepdim=True)).to(x)
# @torch.jit.script
def elu_norm(x):
dtype = x.dtype
x = F.elu(x, 1., False) + 1.
return (x / x.sum(-1, keepdim=True)).to(dtype)
# https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1
class DeltaNet(nn.Module):
def __init__(
self,
d_model: int = None,
hidden_size: int = 1024,
expand_k: float = 1.0,
expand_v: float = 1.0,
num_heads: int = 4,
mode: str = 'fused_chunk',
chunk_size: int = 16,
use_beta: bool = True,
use_gate: bool = True,
use_rope: bool = False,
use_output_norm: bool = True,
use_elu: bool = False,
use_short_conv: bool = True,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = False,
layer_idx: int = None,
qk_activation: str = 'silu',
qk_norm: str = None,
save_memory: str = False,
**kwargs
) -> DeltaNet:
super().__init__()
self.mode = mode
self.qk_activation = qk_activation
self.qk_norm = qk_norm
assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
assert self.qk_norm in ['l2', 'sum']
if d_model is not None:
hidden_size = d_model
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.chunk_size = chunk_size
self.use_gate = use_gate
self.use_output_norm = use_output_norm
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.layer_idx = layer_idx
self.silu = torch.nn.SiLU()
assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
self.use_beta = use_beta
self.use_elu = use_elu
if self.use_beta:
self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation=None)
else:
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
if use_gate:
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
if self.use_gate:
self.norm = FusedRMSNormSwishGate(self.head_v_dim)
else:
self.norm = RMSNorm(self.head_v_dim)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
# change to inference mode.
mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode
last_state = past_key_values[self.layer_idx] if use_cache else None
if attention_mask is not None:
if attention_mask.shape[-1] != hidden_states.shape[-2]:
attention_mask = attention_mask[:, -1:]
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
else:
conv_state_q = last_state[0] if use_cache else None
conv_state_k = last_state[1] if use_cache else None
conv_state_v = last_state[2] if use_cache else None
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = self.q_proj(hidden_states)
q = self.q_conv1d(q, attention_mask, conv_state_q)
k = self.k_conv1d(k, attention_mask, conv_state_k)
v = self.v_conv1d(v, attention_mask, conv_state_v)
else:
q = (self.q_proj(hidden_states))
k = (self.k_proj(hidden_states))
v = self.silu(self.v_proj(hidden_states))
# dealing with left-padding
if attention_mask is not None:
v = v.mul_(attention_mask.unsqueeze(-1))
q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v))
if self.qk_activation != 'silu':
if self.qk_activation == 'relu':
q, k = q.relu(), k.relu()
elif self.qk_activation == 'elu':
q, k = elu_p1(q), elu_p1(k)
elif self.qk_activation == 'identity':
pass
else:
raise NotImplementedError
if self.qk_norm is not None:
if self.qk_norm == 'l2':
k = torch.nn.functional.normalize(k, dim=-1, p=2).to(v) #auto mixed precision type transfer is annoying.
q = torch.nn.functional.normalize(q, dim=-1, p=2).to(v)
elif self.qk_norm == 'sum':
q = sum_norm(q).to(v)
k = sum_norm(k).to(v)
if self.use_beta:
beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid()
else:
beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
state = past_key_values[self.layer_idx][-1] if use_cache else None
if mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_linear_attn_delta_rule(q, k, v, beta, state, output_final_state=use_cache)
elif mode == 'fused_chunk':
assert self.chunk_size in [16, 32, 64]
o, recurrent_state = fused_chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
elif mode == 'chunk':
assert self.chunk_size in [16, 32, 64]
o, recurrent_state = chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
state = (conv_state, recurrent_state)
else:
state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
else:
state = (recurrent_state,)
past_key_values.update(state, self.layer_idx)
o = rearrange(o, 'b h l d -> b l h d')
if self.use_gate:
g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads)
o = self.norm(o, g)
else:
o = self.norm(o)
o = rearrange(o, 'b l h d -> b l (h d)')
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
# for q/k/v each
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.value_dim, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
return state

234
finetune/lora/v6/fla/layers/gated_abc.py vendored Normal file
View File

@ -0,0 +1,234 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import warnings
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers.cache_utils import Cache
from fla.modules import (FusedRMSNormSwishGateLinear, RMSNormLinear,
RotaryEmbedding, ShortConvolution)
from fla.modules.activations import ACT2FN, swiglu_linear, swish
from fla.ops.abc.chunk_gate import chunk_gated_abc
class GatedABCAttention(nn.Module):
def __init__(
self,
hidden_size: int = 1024,
expand_k: float = 1.,
expand_v: float = 1.,
num_heads: int = 4,
num_kv_heads: Optional[int] = None,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
num_slots: Optional[int] = None,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
gate_low_rank_dim: Optional[int] = None,
gate_logit_normalizer: int = 16,
feature_map: str = 'swish',
use_rope: bool = False,
use_output_gate: bool = False,
use_norm: bool = True,
layer_idx: Optional[int] = None,
**kwargs
) -> GatedABCAttention:
super().__init__()
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.key_dim_per_group = self.key_dim // self.num_kv_groups
self.value_dim_per_group = self.value_dim // self.num_kv_groups
self.head_k_dim = self.key_dim // self.num_heads
self.head_v_dim = self.value_dim // self.num_heads
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
if gate_low_rank_dim is None:
gate_low_rank_dim = self.hidden_size // 16
self.gate_low_rank_dim = gate_low_rank_dim
self.gate_logit_normalizer = gate_logit_normalizer
self.feature_map = feature_map
self.use_rope = use_rope
self.use_output_gate = use_output_gate
self.use_norm = use_norm
if num_slots is None:
num_slots = self.head_k_dim
self.num_slots = num_slots
self.layer_idx = layer_idx
if layer_idx is None:
warnings.warn(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)
if use_output_gate:
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
if self.use_norm:
if self.use_output_gate:
self.g_norm = FusedRMSNormSwishGateLinear(self.hidden_size, elementwise_affine, norm_eps)
else:
self.g_norm = RMSNormLinear(self.hidden_size, elementwise_affine, norm_eps)
self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
if self.use_rope:
self.rotary = RotaryEmbedding(self.head_k_dim)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
lower_bound: Optional[torch.Tensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
last_state = past_key_values[self.layer_idx] if use_cache else None
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
else:
conv_state_q = last_state[0] if use_cache else None
conv_state_k = last_state[1] if use_cache else None
conv_state_v = last_state[2] if use_cache else None
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = self.q_conv1d(q, attention_mask, conv_state_q)
k = self.k_conv1d(k, attention_mask, conv_state_k)
v = self.v_conv1d(v, attention_mask, conv_state_v)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
f = self.f_proj(hidden_states)
if self.use_rope:
q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
seqlen_offset = 0
if past_key_values is not None:
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
q, k = self.rotary(q, k, seqlen_offset)
q = rearrange(q, 'b n h d -> b h n d', h=self.num_heads)
k = rearrange(k, 'b n h d -> b h n d', h=self.num_kv_heads)
else:
q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
if self.num_kv_groups > 1:
k = repeat(k, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups)
else:
k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_kv_heads)
if self.num_kv_groups > 1:
v = repeat(v, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups)
f = repeat(f, 'b n (h m) -> b (h g) n m', h=self.num_kv_heads, g=self.num_kv_groups)
else:
v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_kv_heads)
f = rearrange(f, 'b n (h m) -> b h n m', h=self.num_kv_heads)
if self.feature_map is not None:
q, k, v = map(lambda x: ACT2FN[self.feature_map](x), (q, k, v))
f = F.logsigmoid(f) / self.gate_logit_normalizer
s = (1 - f.exp()).to(f.dtype)
# dealing with left-padding
if attention_mask is not None:
s = s.mul_(attention_mask.view(attention_mask.shape[0], 1, -1, 1))
v = v.mul_(attention_mask.view(attention_mask.shape[0], 1, -1, 1))
recurrent_state = last_state[-2:] if use_cache else None
o, recurrent_state = chunk_gated_abc(q, k, v, s, f,
initial_state=recurrent_state,
output_final_state=use_cache)
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
last_state = (conv_state,) + recurrent_state
else:
last_state = (conv_state_q, conv_state_k, conv_state_v) + recurrent_state
else:
last_state = recurrent_state
past_key_values.update(last_state, self.layer_idx, q.shape[2])
o = rearrange(o, 'b h t d -> b t (h d)')
if self.use_norm and not self.use_output_gate:
o = swish(o)
o = self.g_norm(o, self.o_proj.weight, self.o_proj.bias)
elif self.use_output_gate and not self.use_norm:
o = swiglu_linear(self.g_proj(hidden_states), o, self.o_proj.weight, self.o_proj.bias)
elif self.use_output_gate and self.use_norm:
o = self.g_norm(o, self.g_proj(hidden_states), self.o_proj.weight, self.o_proj.bias)
else:
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.value_dim, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_k_dim, self.num_slots),
param.new_zeros(batch_size, self.num_heads, self.num_slots, self.head_v_dim))
return state
def state_size(self, sequence_length: int = 2048):
return self.num_heads * self.key_dim * self.head_v_dim

268
finetune/lora/v6/fla/layers/gla.py vendored Normal file
View File

@ -0,0 +1,268 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers.cache_utils import Cache
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
from fla.modules.activations import ACT2FN
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
class GatedLinearAttention(nn.Module):
r"""
The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
Args:
mode (str, Optional):
Which GLA kernel to use.
Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
Default: `chunk`.
hidden_size (int, Optional):
The hidden size of the input. Default: 1024.
expand_k (float, Optional):
The expansion ratio for the key dim. Default: 0.5.
expand_v (float, Optional):
The expansion ratio for the value dim. Default: 1.0.
num_heads (int, Optional):
The number of heads. Default: 4.
num_kv_heads (int, Optional):
The number of key/value heads, used for MQA. Default: None.
feature_map (str, Optional):
Feature map function applied to queries/keys. Default: None.
use_short_conv (bool, Optional):
Whether to use short convolutions. Default: `False`.
conv_size (int, Optional):
The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
conv_bias (bool, Optional):
Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
share_conv_kernel (bool, Optional):
Whether to apply convolutions berfore q/k/v mapping, only taking effects when `use_short_conv`. Default: `True`.
use_output_gate (bool, Optional):
Whether to use output gate. Default: `True`.
gate_fn (str, Optional):
The activation function for the output gate. Default: `swish`.
elementwise_affine (bool, Optional):
If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
norm_eps (float, Optional):
The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
gate_logit_normalizer (int, Optional):
The normalizer for the gate logits, appied after `logsigmoid`. Default: 16.
gate_low_rank_dim (int, Optional):
The low rank dim for the gate projection. Default: 16.
clamp_min (float, Optional):
The minimum value for the gate logits. Default: None.
fuse_norm (bool, Optional):
Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
layer_idx (int, Optional):
The index of the layer. Default: None.
"""
def __init__(
self,
mode: str = 'chunk',
hidden_size: int = 1024,
expand_k: float = 0.5,
expand_v: float = 1.0,
num_heads: int = 4,
num_kv_heads: Optional[int] = None,
feature_map: Optional[str] = None,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
use_output_gate: bool = True,
gate_fn: str = 'swish',
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
gate_logit_normalizer: int = 16,
gate_low_rank_dim: int = 16,
clamp_min: Optional[float] = None,
fuse_norm: bool = True,
layer_idx: int = None,
) -> GatedLinearAttention:
super().__init__()
self.mode = mode
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.use_output_gate = use_output_gate
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.key_dim_per_group = self.key_dim // self.num_kv_groups
self.value_dim_per_group = self.value_dim // self.num_kv_groups
self.clamp_min = clamp_min
self.layer_idx = layer_idx
assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
if self.use_output_gate:
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
if gate_fn == 'swish' and fuse_norm and use_output_gate:
self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
self.fuse_norm_and_gate = True
else:
self.fuse_norm_and_gate = False
self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
self.gate_fn = ACT2FN[gate_fn]
self.gate_logit_normalizer = gate_logit_normalizer
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
# launching the triton kernel for just one token will actually be slower
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
last_state = past_key_values[self.layer_idx] if use_cache else None
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
else:
conv_state_q = last_state[0] if use_cache else None
conv_state_k = last_state[1] if use_cache else None
conv_state_v = last_state[2] if use_cache else None
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = self.q_conv1d(q, attention_mask, conv_state_q)
k = self.k_conv1d(k, attention_mask, conv_state_k)
v = self.v_conv1d(v, attention_mask, conv_state_v)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
gk = self.gk_proj(hidden_states)
if self.feature_map_fn is not None:
q, k = map(self.feature_map_fn, (q, k))
# dealing with left-padding
if attention_mask is not None:
v = v.mul_(attention_mask.unsqueeze(-1))
q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads)
if self.num_kv_groups > 1:
k, v, gk = (repeat(x, 'b l (h d) -> b (h g) l d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v, gk))
else:
k, v, gk = (rearrange(x, 'b l (h d) -> b h l d', h=self.num_kv_heads) for x in (k, v, gk))
gk = F.logsigmoid(gk) / self.gate_logit_normalizer
if self.clamp_min is not None:
gk = torch.clamp_min(gk, self.clamp_min)
recurrent_state = last_state[-1] if use_cache else None
if mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
elif mode == 'fused_chunk':
o, recurrent_state = fused_chunk_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
elif mode == 'chunk':
o, recurrent_state = chunk_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
last_state = (conv_state, recurrent_state)
else:
last_state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
else:
last_state = (recurrent_state,)
past_key_values.update(last_state, self.layer_idx, q.shape[2])
o = rearrange(o, 'b h l d -> b l h d')
if self.use_output_gate:
g = self.g_proj(hidden_states)
if self.fuse_norm_and_gate:
g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
o = self.g_norm_swish_gate(o, g)
o = rearrange(o, 'b l h d -> b l (h d)')
else:
o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
o = o * self.gate_fn(g)
else:
o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.value_dim, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
return state
def state_size(self, **kwargs) -> int:
state_size = self.key_dim * self.head_v_dim
for module in self.children():
if isinstance(module, ShortConvolution):
state_size += module.state_size
return state_size

165
finetune/lora/v6/fla/layers/hgrn.py vendored Normal file
View File

@ -0,0 +1,165 @@
# -*- coding: utf-8 -*-
# "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.cache_utils import Cache
from fla.modules import FusedRMSNormSwishGate, ShortConvolution
from fla.modules.activations import swiglu
from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
class HGRNAttention(nn.Module):
def __init__(
self,
mode: str = 'chunk',
hidden_size: int = 1024,
num_heads: Optional[int] = None,
expand_ratio: Optional[int] = 1,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
layer_idx: int = None
) -> HGRNAttention:
super().__init__()
self.mode = mode
self.hidden_size = hidden_size
self.num_heads = num_heads
self.expand_ratio = expand_ratio
self.input_dim = int(hidden_size * expand_ratio)
self.head_dim = self.input_dim // self.num_heads
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.layer_idx = layer_idx
assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
assert self.hidden_size % num_heads == 0, f"hidden size must be divisible by num_heads of {num_heads}"
self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
self.g_norm = FusedRMSNormSwishGate(self.input_dim, elementwise_affine, norm_eps)
self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
lower_bound: Optional[torch.Tensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
# launching the triton kernel for just one token will actually be slower
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
last_state = past_key_values[self.layer_idx] if use_cache else None
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
i = self.i_proj(hidden_states)
f = self.f_proj(hidden_states)
else:
conv_state_i = last_state[2] if use_cache else None
conv_state_f = last_state[1] if use_cache else None
i = self.i_conv1d(self.i_proj(hidden_states), attention_mask, conv_state_i)
f = self.f_conv1d(self.f_proj(hidden_states), attention_mask, conv_state_f)
else:
i = self.i_proj(hidden_states)
f = self.f_proj(hidden_states)
# the lower bound for the first layer is zero
if lower_bound is None or self.layer_idx == 0:
i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
else:
g = lower_bound + (1 - lower_bound) * f.sigmoid()
i, f = swiglu(i, 1 - g), g.log()
# dealing with left-padding
if attention_mask is not None:
i = i.mul_(attention_mask.unsqueeze(-1))
i, f = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (i, f))
recurrent_state = last_state[-1] if use_cache else None
if mode == 'chunk':
o, recurrent_state = chunk_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache)
elif mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
last_state = (conv_state, recurrent_state)
else:
last_state = (conv_state_i, conv_state_f, recurrent_state)
else:
last_state = (recurrent_state,)
past_key_values.update(last_state, self.layer_idx, i.shape[2])
o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)'))
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),
param.new_zeros(batch_size, self.hidden_size, self.conv_size),
param.new_zeros(batch_size, self.hidden_size, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_dim),)
return state
def state_size(self, **kwargs) -> int:
state_size = self.hidden_size
for module in self.children():
if isinstance(module, ShortConvolution):
state_size += module.state_size
return state_size

186
finetune/lora/v6/fla/layers/hgrn2.py vendored Normal file
View File

@ -0,0 +1,186 @@
# -*- coding: utf-8 -*-
# "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904]
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.cache_utils import Cache
from fla.modules import RMSNorm, ShortConvolution
from fla.modules.activations import swish
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
class HGRN2Attention(nn.Module):
def __init__(
self,
mode: str = 'chunk',
hidden_size: int = 1024,
num_heads: Optional[int] = None,
expand_ratio: Optional[int] = 128,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
layer_idx: int = None
) -> HGRN2Attention:
super().__init__()
self.mode = mode
self.hidden_size = hidden_size
if expand_ratio is None and num_heads is not None:
expand_ratio = hidden_size // num_heads
elif expand_ratio is not None and num_heads is None:
num_heads = hidden_size // expand_ratio
else:
raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
self.num_heads = num_heads
self.expand_ratio = expand_ratio
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.forget_dim = int(self.num_heads * self.expand_ratio)
self.input_dim = hidden_size
self.layer_idx = layer_idx
assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`."
assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}"
assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}"
self.head_f_dim = self.expand_ratio
self.head_i_dim = self.hidden_size // num_heads
self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu')
self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu')
self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, norm_eps)
self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
lower_bound: Optional[torch.Tensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
# launching the triton kernel for just one token will actually be slower
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
last_state = past_key_values[self.layer_idx] if use_cache else None
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
q = self.q_proj(hidden_states)
f = self.f_proj(hidden_states)
i = self.i_proj(hidden_states)
else:
conv_state_q = last_state[0] if use_cache else None
conv_state_f = last_state[1] if use_cache else None
conv_state_i = last_state[2] if use_cache else None
q = self.q_proj(hidden_states)
f = self.f_proj(hidden_states)
i = self.i_proj(hidden_states)
q = self.q_conv1d(q, attention_mask, conv_state_q)
f = self.f_conv1d(f, attention_mask, conv_state_f)
i = self.i_conv1d(i, attention_mask, conv_state_i)
else:
q = self.q_proj(hidden_states)
f = self.f_proj(hidden_states)
i = self.i_proj(hidden_states)
# dealing with left-padding
if attention_mask is not None:
i = i.mul_(attention_mask.unsqueeze(-1))
q = swish(q)
# the lower bound for the first layer is zero
if lower_bound is None or self.layer_idx == 0:
k, g = 1 - f.sigmoid(), F.logsigmoid(f)
else:
g = lower_bound + (1 - lower_bound) * f.sigmoid()
k, g = 1 - g, g.log()
q, k, i, g = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, i, g))
recurrent_state = last_state[-1] if use_cache else None
if mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
elif mode == 'fused_chunk':
o, recurrent_state = fused_chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
elif mode == 'chunk':
o, recurrent_state = chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
last_state = (conv_state, recurrent_state)
else:
last_state = (conv_state_q, conv_state_f, conv_state_i, recurrent_state)
else:
last_state = (recurrent_state,)
past_key_values.update(last_state, self.layer_idx, q.shape[2])
o = self.g_norm(rearrange(o, 'b h l d -> b l (h d)'))
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
state += (param.new_zeros(batch_size, self.forget_dim, self.conv_size),
param.new_zeros(batch_size, self.forget_dim, self.conv_size),
param.new_zeros(batch_size, self.input_dim, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_f_dim, self.head_i_dim),)
return state
def state_size(self, **kwargs) -> int:
state_size = self.forget_dim * self.head_i_dim
for module in self.children():
if isinstance(module, ShortConvolution):
state_size += module.state_size
return state_size

View File

@ -0,0 +1,156 @@
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from fla.modules import RMSNorm
from fla.modules.feature_map import (DPFPFeatureMap, HadamardFeatureMap,
HedgehogFeatureMap, T2RFeatureMap)
from fla.ops.linear_attn import (chunk_linear_attn, fused_chunk_linear_attn,
fused_recurrent_linear_attn)
class LinearAttention(nn.Module):
def __init__(
self,
hidden_size: str = 1024,
expand_k: int = 1.0,
expand_v: int = 1.0,
num_heads: int = 8,
mode: str = 'chunk',
feature_map: str = 'elementwise_product',
tie_feature_map_qk: bool = False,
output_norm: str = 'rmsnorm',
norm_q: bool = False,
norm_k: bool = False,
# standard linear attention normalization
do_feature_map_norm: bool = False,
elementwise_affine: bool = True,
norm_eps: float = 1e-5,
**kwargs,
):
super().__init__()
assert feature_map in ['elu', 'relu', 'hedgehog', 't2r', 'dpfp',
'identity', 'elementwise_product'], f"Not supported feature map `{feature_map}`."
assert output_norm in ['rmsnorm', 'identity'], f"Not supported output norm `{output_norm}`."
self.hidden_size
self.mode = mode
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.num_heads = num_heads
assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
if feature_map == 'hedgehog':
if tie_feature_map_qk:
self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
else:
self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 't2r':
if tie_feature_map_qk:
self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
else:
self.feature_map_q = T2RFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 'elementwise_product':
if tie_feature_map_qk:
self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
else:
self.feature_map_q = HadamardFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 'dpfp':
self.feature_map_q = DPFPFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = DPFPFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 'elu':
def elu(x):
return F.elu(x) + 1
self.feature_map_q = elu
self.feature_map_k = elu
elif feature_map == 'relu':
self.feature_map_q = nn.ReLU()
self.feature_map_k = nn.ReLU()
elif feature_map == 'identity':
self.feature_map_q = nn.Identity()
self.feature_map_k = nn.Identity()
else:
raise NotImplementedError
self.do_feature_map_norm = do_feature_map_norm
if output_norm == 'rmsnorm':
self.norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
elif output_norm == 'identity':
self.norm = nn.Identity()
else:
raise NotImplementedError
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
self.norm_q = norm_q
self.norm_k = norm_k
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(self, x):
mode = self.mode
q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
q = self.feature_map_q(q)
k = self.feature_map_k(k)
if self.norm_q:
q = q / (q.sum(-1, keepdim=True) + 1e-4)
if self.norm_k:
k = k / (k.sum(-1, keepdim=True) + 1e-4)
if mode == 'chunk':
o = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
elif mode == 'fused_chunk':
o = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
elif mode == 'fused_recurrent':
o = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
else:
raise NotImplementedError
o = self.norm(o)
o = rearrange(o, 'b h n d -> b n (h d)')
o = self.o_proj(o)
return o
if __name__ == '__main__':
import torch
batch = 4
seq_len = 1024
hidden_size = 1024
x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True)
model = LinearAttention(hidden_size, feature_map='dplp').to(torch.bfloat16).cuda()
y = model(x)
print(y.shape)
y.sum().backward()
print(x.grad.shape)

View File

@ -0,0 +1,271 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange, repeat
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
from fla.modules.rotary import RotaryEmbedding
from fla.ops.retention import (chunk_retention, fused_chunk_retention,
fused_recurrent_retention, parallel_retention)
class MultiScaleRetention(nn.Module):
r"""
The layer implementaion for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf). # noqa
Args:
mode (str, Optional):
Which Retention kernel to use.
Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`.
Default: `fused_chunk`.
hidden_size (int, Optional):
The hidden size of the input. Default: 1024.
expand_k (float, Optional):
The expansion ratio for the key dim. Default: 1.0.
expand_v (float, Optional):
The expansion ratio for the value dim. Default: 2.0.
num_heads (int, Optional):
The number of heads. Default: 8.
num_kv_heads (int, Optional):
The number of key/value heads, used for MQA. Default: None.
feature_map (str, Optional):
Feature map function applied to queries/keys. Default: None.
use_short_conv (bool, Optional):
Whether to use short convolutions. Default: `False`.
conv_size (int, Optional):
The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
conv_bias (bool, Optional):
Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
share_conv_kernel (bool, Optional):
Whether to apply convolutions berfore q/k/v mapping, only taking effects when `use_short_conv`. Default: `True`.
use_output_gate (bool, Optional):
Whether to use output gate. Default: `True`.
gate_fn (str, Optional):
The activation function for the output gate. Default: `swish`.
elementwise_affine (bool, Optional):
If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
norm_eps (float, Optional):
The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
fuse_norm (bool, Optional):
Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
layer_idx (int, Optional):
The index of the layer. Default: None.
"""
def __init__(
self,
mode: str = 'fused_chunk',
hidden_size: int = 1024,
expand_k: float = 1.0,
expand_v: float = 2.0,
num_heads: int = 8,
num_kv_heads: Optional[int] = None,
feature_map: Optional[str] = None,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
use_output_gate: bool = True,
gate_fn: str = 'swish',
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
fuse_norm: bool = True,
layer_idx: int = None,
**kwargs
) -> MultiScaleRetention:
super().__init__()
self.mode = mode
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.use_output_gate = use_output_gate
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.key_dim_per_group = self.key_dim // self.num_kv_groups
self.value_dim_per_group = self.value_dim // self.num_kv_groups
self.layer_idx = layer_idx
assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
if self.use_output_gate:
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
if gate_fn == 'swish' and fuse_norm and use_output_gate:
self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
self.fuse_norm_and_gate = True
else:
self.fuse_norm_and_gate = False
self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
self.gate_fn = ACT2FN[gate_fn]
# TODO: fix this issue
# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180
# Ideally, we would want to support arbitrary d_head_qk
assert self.head_qk_dim <= 256, "head_qk_dim must be less than or equal to 256"
self.rotary = RotaryEmbedding(dim=self.head_qk_dim)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
# launching the triton kernel for just one token will actually be slower
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
last_state = past_key_values[self.layer_idx] if use_cache else None
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
else:
conv_state_q = last_state[0] if use_cache else None
conv_state_k = last_state[1] if use_cache else None
conv_state_v = last_state[2] if use_cache else None
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = self.q_conv1d(q, attention_mask, conv_state_q)
k = self.k_conv1d(k, attention_mask, conv_state_k)
v = self.v_conv1d(v, attention_mask, conv_state_v)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# dealing with left-padding
if attention_mask is not None:
v = v.mul_(attention_mask.unsqueeze(-1))
q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
if self.feature_map_fn is not None:
q, k = map(self.feature_map_fn, (q, k))
seqlen_offset, max_seqlen = 0, None
if past_key_values is not None:
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
max_seqlen = q.shape[1] + seqlen_offset
if attention_mask is not None:
# to deliminate the offsets of padding tokens
seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
max_seqlen = q.shape[1] + max(seqlen_offset)
q, k = self.rotary(q, k, seqlen_offset, max_seqlen)
q = q.transpose(1, 2)
if self.num_kv_groups > 1:
k = repeat(k, 'b t h d -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups)
v = repeat(v, 'b t (h d) -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups)
else:
k, v = rearrange(k, 'b t h d -> b h t d'), rearrange(v, 'b t (h d) -> b h t d', h=self.num_kv_heads)
state = last_state[-1] if use_cache else None
if mode == 'chunk':
o, recurrent_state = chunk_retention(q, k, v, initial_state=state, output_final_state=use_cache)
elif mode == 'fused_chunk':
o, recurrent_state = fused_chunk_retention(q, k, v, initial_state=state, output_final_state=use_cache)
elif mode == 'parallel':
o, recurrent_state = parallel_retention(q, k, v, initial_state=state, output_final_state=use_cache)
elif mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_retention(q, k, v, initial_state=state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
last_state = (conv_state, recurrent_state)
else:
last_state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
else:
last_state = (recurrent_state,)
past_key_values.update(last_state, self.layer_idx, q.shape[2])
o = rearrange(o, 'b h l d -> b l h d')
if self.use_output_gate:
g = self.g_proj(hidden_states)
if self.fuse_norm_and_gate:
g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
o = self.g_norm_swish_gate(o, g)
o = rearrange(o, 'b l h d -> b l (h d)')
else:
o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
o = o * self.gate_fn(g)
else:
o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.value_dim, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
return state
def state_size(self, **kwargs) -> int:
state_size = self.key_dim * self.head_v_dim
for module in self.children():
if isinstance(module, ShortConvolution):
state_size += module.state_size
return state_size

137
finetune/lora/v6/fla/layers/rebased.py vendored Normal file
View File

@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
"""
https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from fla.modules.feature_map import RebasedFeatureMap
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
from fla.ops.rebased import parallel_rebased
class ReBasedLinearAttention(nn.Module):
def __init__(
self,
hidden_size: int,
l_max: int = 2048,
feature_dim: int = 16,
num_key_value_heads: int = 16,
num_heads: int = 16,
use_gamma: Optional[bool] = True,
use_beta: Optional[bool] = True,
normalize: Optional[bool] = True,
causal: bool = True,
eps: float = 1e-5,
mode: str = "parallel",
layer_idx: Optional[int] = None,
**kwargs
) -> ReBasedLinearAttention:
super().__init__()
self.hidden_size = hidden_size
self.l_max = l_max
self.mode = mode
assert self.mode in ["fused_chunk", "parallel", 'chunk']
# linear attention
self.feature_dim = feature_dim
self.num_key_value_heads = num_key_value_heads
self.num_heads = num_heads
self.head_dim = self.hidden_size // self.num_key_value_heads
self.use_gamma = use_gamma
self.use_beta = use_beta
self.normalize = normalize
self.causal = causal
self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.dropout = nn.Identity()
self.eps = eps
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(self, hidden_states: torch.Tensor, **kwargs):
mode = self.mode
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v])
q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
if mode == "fused_chunk":
o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)
elif mode == 'chunk':
o = chunk_linear_attn(q, k, v, normalize=True, scale=1)
elif mode == 'parallel':
assert q.shape[-1] <= 128
o = parallel_rebased(q, k, v, self.eps, True, True)
o = rearrange(o, "b h l d -> b l (h d)")
o = self.o_proj(o)
o = self.dropout(o)
return o
# https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
"""
x (torch.Tensor): tensor of shape (b, d, l)
y (torch.Tensor): tensor of shape (b, d, l)
"""
# hidden_states = hidden_states.transpose(1, 2)
b, l, _ = hidden_states.size()
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2)
k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Linear attention
q, k = self.feature_map(q), self.feature_map(k)
q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
# Compute attention
if self.causal:
y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
else:
y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
y = rearrange(y, 'b h l d -> b l (h d)')
y = self.o_proj(y.to(hidden_states.dtype))
y = self.dropout(y)
return y.to(hidden_states.dtype)
if __name__ == '__main__':
batch = 4
seq_len = 1024
hidden_size = 1024
dtype = torch.float32
x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True)
dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda()
model = ReBasedLinearAttention(hidden_size=hidden_size, mode='parallel').to(dtype).cuda()
y = model(x)
y.backward(dy, retain_graph=True)
x_grad, x.grad = x.grad, None
print(model.mode)
model.mode = 'fused_chunk'
y2 = model(x)
print(model.mode)
y2.backward(dy)
# assert y.allclose(y2, 0, 1e-4), breakpoint()
# assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint()
print("Pass")

264
finetune/lora/v6/fla/layers/rwkv6.py vendored Normal file
View File

@ -0,0 +1,264 @@
# -*- coding: utf-8 -*-
# "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence"[https://arxiv.org/abs/2404.05892]
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from fla.modules import FusedLayerNormSwishGate, LayerNorm
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
class RWKV6Attention(nn.Module):
def __init__(
self,
mode: str = 'chunk',
hidden_size: int = 1024,
expand_k: float = 0.5,
expand_v: float = 1.0,
num_heads: int = 4,
gate_fn: str = 'swish',
proj_low_rank_dim: int = 32,
gate_low_rank_dim: int = 64,
fuse_norm: bool = True,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
layer_idx: int = None,
**kwargs
) -> RWKV6Attention:
super().__init__()
self.mode = mode
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.proj_low_rank_dim = proj_low_rank_dim
self.gate_low_rank_dim = gate_low_rank_dim
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.layer_idx = layer_idx
assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.x_proj = nn.Sequential(
LerpLinear(hidden_size, proj_low_rank_dim * 5),
nn.Tanh(),
nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=True)
)
self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_qk_dim))
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
if gate_fn == 'swish' and fuse_norm:
self.g_norm_swish_gate = FusedLayerNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
self.fuse_norm_and_gate = True
else:
self.fuse_norm_and_gate = False
self.g_norm = LayerNorm(self.head_v_dim, elementwise_affine, norm_eps)
self.gate_fn = ACT2FN[gate_fn]
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
if isinstance(module, nn.Parameter):
nn.init.xavier_uniform_(module, gain=2 ** -2.5)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
batch_size, seq_len, hidden_size = hidden_states.size()
# launching the triton kernel for just one token will actually be slower
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
delta = self.time_shift(hidden_states) - hidden_states
x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
r, w, k, v, g = torch.einsum('b l n r, n r d-> b l n d',
self.x_proj[1](x),
self.x_proj[2].weight.view(5, -1, hidden_size)).unbind(-2)
r = self.r_proj(hidden_states, r, delta)
w = self.w_proj(hidden_states, w, delta)
k = self.k_proj(hidden_states, k, delta)
v = self.v_proj(hidden_states, v, delta)
g = self.g_proj(hidden_states, g, delta)
# dealing with left-padding
if attention_mask is not None:
v = v.mul_(attention_mask.unsqueeze(-1))
r, w, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (r, w, k, v))
w = -torch.exp(w)
u = self.bonus
last_state = past_key_values[self.layer_idx] if use_cache else None
state = last_state[-1] if use_cache else None
if mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_rwkv6(r, k, v, w, u, initial_state=state, output_final_state=use_cache)
elif mode == 'chunk':
o, recurrent_state = chunk_rwkv6(r, k, v, w, u, initial_state=state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
past_key_values.update((recurrent_state,), self.layer_idx, r.shape[2])
o = rearrange(o, 'b h l d -> b l h d')
if self.fuse_norm_and_gate:
g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
o = self.g_norm_swish_gate(o, g)
o = rearrange(o, 'b l h d -> b l (h d)')
else:
o = self.g_norm(o)
o = rearrange(o, 'b l h d -> b l (h d)')
o = o * self.gate_fn(g)
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
return state
def state_size(self, **kwargs) -> int:
state_size = self.key_dim * self.head_v_dim
return state_size
class LoRA(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
low_rank_dim: int,
bias: Optional[bool] = True
):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.low_rank_dim = low_rank_dim
self.bias = bias
self.lora = nn.Sequential(
nn.Linear(input_dim, low_rank_dim, bias=False),
nn.Tanh(),
nn.Linear(low_rank_dim, output_dim, bias=bias)
)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}("
s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
if not self.bias:
s += f", bias={self.bias}"
s += ")"
return s
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lora(x)
class LerpLinear(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
low_rank_dim: Optional[int] = None
):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.low_rank_dim = low_rank_dim
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
if low_rank_dim is None:
self.linear = nn.Linear(input_dim, output_dim, bias=False)
else:
self.linear = LoRA(input_dim, output_dim, low_rank_dim)
self.mu = nn.Parameter(torch.zeros(input_dim))
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
if self.low_rank_dim is not None:
s += f", low_rank_dim={self.low_rank_dim}"
s += ")"
return s
def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
if delta is None:
shifted = self.time_shift(x)
if len(shifted.shape) == 2:
shifted = shifted.unsqueeze(1)
delta = shifted - x
return self.linear(x + delta * self.mu)
class DDLerpLinear(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
low_rank_dim: Optional[int] = None
):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.low_rank_dim = low_rank_dim
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
if low_rank_dim is None:
self.linear = nn.Linear(input_dim, output_dim, bias=False)
else:
self.linear = LoRA(input_dim, output_dim, low_rank_dim)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
if self.low_rank_dim is not None:
s += f", low_rank_dim={self.low_rank_dim}"
s += ")"
return s
def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
if delta is None:
shifted = self.time_shift(x)
if len(shifted.shape) == 2:
shifted = shifted.unsqueeze(1)
delta = shifted - x
return self.linear(x + delta * mu)

View File

@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN
from fla.modules import FusedRMSNormSwishGate, RMSNorm
from fla.ops.simple_gla import chunk_simple_gla
class SimpleGatedLinearAttention(nn.Module):
r"""
The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.
Args:
mode (str, Optional):
Which GLA kernel to use.
Currently available: `chunk`.
Default: `chunk`.
hidden_size (int, Optional):
The hidden size of the input. Default: 1024.
expand_k (float, Optional):
The expansion ratio for the key dim. Default: 0.5.
expand_v (float, Optional):
The expansion ratio for the value dim. Default: 1.0.
num_heads (int, Optional):
The number of heads. Default: 4.
gate_fn (str, Optional):
The activation function for the output gate. Default: `swish`.
elementwise_affine (bool, Optional):
If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
norm_eps (float, Optional):
The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
gate_logit_normalizer (int, Optional):
The normalizer for the gate logits, appied after `logsigmoid`. Default: 16.
fuse_norm (bool, Optional):
Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
layer_idx (int, Optional):
The index of the layer. Default: None.
"""
def __init__(
self,
mode: str = 'chunk',
hidden_size: int = 1024,
expand_k: float = 1.0,
expand_v: float = 2.0,
num_heads: int = 4,
gate_fn: str = 'swish',
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
gate_logit_normalizer: int = 16,
fuse_norm: bool = True,
**kwargs
) -> SimpleGatedLinearAttention:
super().__init__()
self.hidden_size = hidden_size
self.mode = mode
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
assert mode in ['chunk'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.num_heads = num_heads
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.gate_fn = ACT2FN[gate_fn]
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
self.gk_proj = nn.Linear(hidden_size, self.num_heads)
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
if gate_fn == 'swish' and fuse_norm:
self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
self.fuse_norm_and_gate = True
else:
self.fuse_norm_and_gate = False
self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
self.gate_logit_normalizer = gate_logit_normalizer
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(self, x):
mode = self.mode
q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
gk = rearrange(self.gk_proj(x), 'b n h -> b h n')
gk = (F.logsigmoid(gk) / self.gate_logit_normalizer)
if mode == 'chunk':
o = chunk_simple_gla(q, k, v, gk)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
o = rearrange(o, 'b h l d -> b l h d')
g = self.g_proj(x)
if self.fuse_norm_and_gate:
g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
o = self.g_norm_swish_gate(o, g)
o = rearrange(o, 'b l h d -> b l (h d)')
else:
o = self.g_norm(o)
o = rearrange(o, 'b l h d -> b l (h d)')
o = o * self.gate_fn(g)
o = self.o_proj(o)
return o
if __name__ == '__main__':
batch = 4
seq_len = 1024
hidden_size = 2048
x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True)
model = SimpleGatedLinearAttention(hidden_size=hidden_size, mode='chunk').to(torch.bfloat16).cuda()
y = model(x)
print(y.shape)
y.sum().backward()
print(x.grad.shape)

29
finetune/lora/v6/fla/models/__init__.py vendored Normal file
View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel
from fla.models.delta_net import (DeltaNetConfig, DeltaNetForCausalLM,
DeltaNetModel)
from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel
from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel
from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model
from fla.models.linear_attn import (LinearAttentionConfig,
LinearAttentionForCausalLM,
LinearAttentionModel)
from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel
from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel
from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model
from fla.models.transformer import (TransformerConfig, TransformerForCausalLM,
TransformerModel)
__all__ = [
'ABCConfig', 'ABCForCausalLM', 'ABCModel',
'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
'GLAConfig', 'GLAForCausalLM', 'GLAModel',
'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel',
'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model',
'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel',
'MambaConfig', 'MambaForCausalLM', 'MambaModel',
'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel',
'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'
]

View File

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.abc.configuration_abc import ABCConfig
from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel
AutoConfig.register(ABCConfig.model_type, ABCConfig)
AutoModel.register(ABCConfig, ABCModel)
AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM)
__all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel']

View File

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
class ABCConfig(PretrainedConfig):
model_type = 'abc'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 2048,
gate_low_rank_dim: int = 16,
clamp_min: float = -32,
clamp_max: float = 32,
hidden_ratio: Optional[int] = 4,
intermediate_size: Optional[int] = None,
num_hidden_layers: int = 24,
num_heads: int = 4,
num_slots: Optional[int] = 64,
use_short_conv: bool = True,
conv_size: int = 4,
share_conv_kernel: bool = True,
exapnd_k: float = 0.5,
exapnd_v: float = 1,
hidden_act: str = "swish",
max_position_embeddings: int = 2048,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-6,
use_cache: bool = True,
pad_token_id: int = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
initializer_range: float = 0.02,
tie_word_embeddings: bool = False,
fuse_norm: bool = True,
fuse_cross_entropy: bool = True,
**kwargs
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.gate_low_rank_dim = gate_low_rank_dim
self.clamp_min = clamp_min
self.clamp_max = clamp_max
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_heads = num_heads
self.num_slots = num_slots
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.share_conv_kernel = share_conv_kernel
self.expand_k = exapnd_k
self.expand_v = exapnd_v
self.hidden_act = hidden_act
self.elementwise_affine = elementwise_affine
self.norm_eps = norm_eps
self.use_cache = use_cache
self.initializer_range = initializer_range
self.fuse_cross_entropy = fuse_cross_entropy
self.fuse_norm = fuse_norm
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,394 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from fla.layers.abc import ABCAttention
from fla.models.abc.configuration_abc import ABCConfig
from fla.models.utils import RecurrentCache
from fla.modules import FusedCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear
logger = logging.get_logger(__name__)
class ABCMLP(nn.Module):
def __init__(
self,
hidden_size: int,
hidden_ratio: Optional[int] = None,
intermediate_size: Optional[int] = None,
hidden_act: str = 'swish'
) -> ABCMLP:
super().__init__()
self.hidden_size = hidden_size
# the final number of params is `hidden_ratio * hidden_size^2`
# `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
if hidden_ratio is None:
hidden_ratio = 4
if intermediate_size is None:
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
y = self.gate_proj(x)
gate, y = y.chunk(2, -1)
return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
class ABCBlock(nn.Module):
def __init__(self, config: ABCConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.attn = ABCAttention(
hidden_size=config.hidden_size,
expand_k=config.expand_k,
expand_v=config.expand_v,
num_heads=config.num_heads,
num_slots=config.num_slots,
use_short_conv=config.use_short_conv,
conv_size=config.conv_size,
share_conv_kernel=config.share_conv_kernel,
gate_fn=config.hidden_act,
elementwise_affine=config.elementwise_affine,
norm_eps=config.norm_eps,
clamp_min=config.clamp_min,
clamp_max=config.clamp_max,
fuse_norm=config.fuse_norm,
layer_idx=layer_idx
)
self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.mlp = ABCMLP(
hidden_size=config.hidden_size,
hidden_ratio=config.hidden_ratio,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.attn_norm(hidden_states)
hidden_states, attentions, past_key_values = self.attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions
)
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, attentions, past_key_values)
return outputs
class ABCPreTrainedModel(PreTrainedModel):
config_class = ABCConfig
supports_gradient_checkpointing = True
_no_split_modules = ['ABCBlock']
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(
self,
module: nn.Module,
rescale_prenorm_residual: bool = True,
num_residuals_per_layer: int = 2,
):
if isinstance(module, (nn.Linear, nn.Conv1d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["o_proj.weight", "down_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
with torch.no_grad():
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
class ABCModel(ABCPreTrainedModel):
def __init__(self, config: ABCConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([ABCBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, # noqa
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None
) -> Union[Tuple, BaseModelOutputWithPast]:
if output_attentions:
warnings.warn("`ABCModel` does not `output_attentions` now, setting it to `False`.")
output_attentions = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
hidden_states = inputs_embeds
if use_cache:
if past_key_values is None:
past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_attns = () if output_attentions else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
past_key_values,
use_cache,
output_attentions
)
else:
hidden_states, attentions, past_key_values = layer(
hidden_states,
attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions
)
if output_attentions:
all_attns += (attentions,)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict:
return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_attns
)
class ABCForCausalLM(ABCPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = ABCModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embeddings
def set_input_embeddings(self, value):
self.model.embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def generate(self, *args, **kwargs):
try:
return super().generate(*args, **kwargs)
except AttributeError as exception:
if 'past_key_values' in str(exception):
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
else:
raise exception
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs
):
# only last token for `inputs_ids` if the `past_key_values` is passed along.
if past_key_values is not None:
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}
model_inputs['past_key_values'] = past_key_values
return model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if self.config.fuse_cross_entropy:
loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
else:
loss_fct = nn.CrossEntropyLoss()
# Enable model parallelism
labels = labels.to(logits.device)
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.delta_net.configuration_delta_net import \
DeltaNetConfig
from fla.models.delta_net.modeling_delta_net import (
DeltaNetForCausalLM, DeltaNetModel)
AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig)
AutoModel.register(DeltaNetConfig, DeltaNetModel)
AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)
__all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel']

View File

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
class DeltaNetConfig(PretrainedConfig):
model_type = 'delta_net'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 2048,
expand_k: int = 1,
expand_v: int = 1,
use_gate: bool = False,
use_short_conv: bool = True,
conv_size: int = 4,
share_conv_kernel: bool = False,
use_rope: bool = False,
use_beta: bool = True,
use_output_norm: bool = True,
hidden_ratio: Optional[int] = 4,
intermediate_size: Optional[int] = None,
num_hidden_layers: int = 24,
num_heads: int = 4,
attn_mode: str = "chunk",
qk_norm: str = 'l2',
qk_activation: str = 'silu',
chunk_size: int = 64,
hidden_act: str = "swish",
max_position_embeddings: int = 2048,
rms_norm_eps: float = 1e-6,
use_cache: bool = True,
pad_token_id: int = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = False,
initializer_range: float = 0.02,
fuse_cross_entropy: bool = True,
**kwargs
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_heads = num_heads
self.attn_mode = attn_mode
self.hidden_act = hidden_act
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.initializer_range = initializer_range
self.fuse_cross_entropy = fuse_cross_entropy
self.use_gate = use_gate
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.share_conv_kernel = share_conv_kernel
self.use_rope = use_rope
self.use_beta = use_beta
self.use_output_norm = use_output_norm
self.qk_norm = qk_norm
self.qk_activation = qk_activation
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,405 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from fla.layers.delta_net import DeltaNet
from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
from fla.models.utils import RecurrentCache
from fla.modules import FusedCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear
logger = logging.get_logger(__name__)
class DeltaNetMLP(nn.Module):
def __init__(
self,
hidden_size: int,
hidden_ratio: Optional[int] = None,
intermediate_size: Optional[int] = None,
hidden_act: str = 'swish'
) -> DeltaNetMLP:
super().__init__()
self.hidden_size = hidden_size
# the final number of params is `hidden_ratio * hidden_size^2`
# `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
if hidden_ratio is None:
hidden_ratio = 4
if intermediate_size is None:
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
y = self.gate_proj(x)
gate, y = y.chunk(2, -1)
return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
class DeltaNetBlock(nn.Module):
def __init__(self, config: DeltaNetConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
self.attn = DeltaNet(
mode=config.attn_mode,
hidden_size=config.hidden_size,
expand_k=config.expand_k,
expand_v=config.expand_v,
num_heads=config.num_heads,
use_gate=config.use_gate,
use_rope=config.use_rope,
use_beta=config.use_beta,
use_short_conv=config.use_short_conv,
use_output_norm=config.use_output_norm,
conv_size=config.conv_size,
share_conv_kernel=config.share_conv_kernel,
layer_idx=layer_idx,
qk_norm=config.qk_norm,
qk_activation=config.qk_activation
)
self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
self.mlp = DeltaNetMLP(
hidden_size=config.hidden_size,
hidden_ratio=config.hidden_ratio,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.attn_norm(hidden_states)
hidden_states, attentions, past_key_values = self.attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions
)
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, attentions, past_key_values)
return outputs
class DeltaNetPreTrainedModel(PreTrainedModel):
config_class = DeltaNetConfig
supports_gradient_checkpointing = True
_no_split_modules = ['DeltaNetBlock']
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(
self,
module: nn.Module,
rescale_prenorm_residual: bool = True,
num_residuals_per_layer: int = 2,
):
if isinstance(module, (nn.Linear, nn.Conv1d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["o_proj.weight", "down_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
with torch.no_grad():
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
class DeltaNetModel(DeltaNetPreTrainedModel):
def __init__(self, config: DeltaNetConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([DeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, # noqa
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None
) -> Union[Tuple, BaseModelOutputWithPast]:
if output_attentions:
warnings.warn("`DeltaNetModel` does not `output_attentions` now, setting it to `False`.")
output_attentions = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
hidden_states = inputs_embeds
if use_cache:
if past_key_values is None:
past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_attns = () if output_attentions else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
past_key_values,
use_cache,
output_attentions
)
else:
hidden_states, attentions, past_key_values = layer(
hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions
)
if output_attentions:
all_attns += (attentions,)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = past_key_values
# if use_cache:
# next_cache = past_key_values.to_legacy_cache()
if not return_dict:
return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_attns
)
class DeltaNetForCausalLM(DeltaNetPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = DeltaNetModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embeddings
def set_input_embeddings(self, value):
self.model.embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def generate(self, *args, **kwargs):
try:
return super().generate(*args, **kwargs)
except AttributeError as exception:
if 'past_key_values' in str(exception):
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
else:
raise exception
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs
):
# only last token for `inputs_ids` if the `past_key_values` is passed along.
if past_key_values is not None:
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
# breakpoint()
input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard.
# Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {'input_ids': input_ids.contiguous()}
model_inputs.update({
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if self.config.fuse_cross_entropy:
loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
else:
loss_fct = nn.CrossEntropyLoss()
# Enable model parallelism
labels = labels.to(logits.device)
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.gla.configuration_gla import GLAConfig
from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel
AutoConfig.register(GLAConfig.model_type, GLAConfig)
AutoModel.register(GLAConfig, GLAModel)
AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM)
__all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel']

View File

@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
class GLAConfig(PretrainedConfig):
model_type = 'gla'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 2048,
expand_k: int = 0.5,
expand_v: int = 1,
hidden_ratio: Optional[int] = 4,
intermediate_size: Optional[int] = None,
num_hidden_layers: int = 24,
num_heads: int = 4,
num_kv_heads: Optional[int] = None,
feature_map: Optional[str] = None,
attn_mode: str = "chunk",
use_short_conv: bool = False,
conv_size: int = 4,
share_conv_kernel: bool = True,
use_output_gate: bool = True,
clamp_min: Optional[float] = None,
hidden_act: str = "swish",
max_position_embeddings: int = 2048,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-6,
use_gk: bool = True,
use_gv: bool = False,
use_cache: bool = True,
pad_token_id: int = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = False,
initializer_range: float = 0.02,
fuse_norm: bool = True,
fuse_cross_entropy: bool = True,
**kwargs
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.feature_map = feature_map
self.attn_mode = attn_mode
self.clamp_min = clamp_min
self.hidden_act = hidden_act
self.elementwise_affine = elementwise_affine
self.norm_eps = norm_eps
self.use_gk = use_gk
self.use_gv = use_gv
self.use_cache = use_cache
self.initializer_range = initializer_range
self.fuse_norm = fuse_norm
self.fuse_cross_entropy = fuse_cross_entropy
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.share_conv_kernel = share_conv_kernel
self.use_output_gate = use_output_gate
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,403 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from fla.layers.gla import GatedLinearAttention
from fla.models.gla.configuration_gla import GLAConfig
from fla.models.utils import RecurrentCache
from fla.modules import FusedCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear
logger = logging.get_logger(__name__)
class GLAMLP(nn.Module):
def __init__(
self,
hidden_size: int,
hidden_ratio: Optional[int] = None,
intermediate_size: Optional[int] = None,
hidden_act: str = 'swish'
) -> GLAMLP:
super().__init__()
self.hidden_size = hidden_size
# the final number of params is `hidden_ratio * hidden_size^2`
# `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
if hidden_ratio is None:
hidden_ratio = 4
if intermediate_size is None:
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
y = self.gate_proj(x)
gate, y = y.chunk(2, -1)
return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
class GLABlock(nn.Module):
def __init__(self, config: GLAConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.attn = GatedLinearAttention(
mode=config.attn_mode,
hidden_size=config.hidden_size,
expand_k=config.expand_k,
expand_v=config.expand_v,
num_heads=config.num_heads,
num_kv_heads=config.num_kv_heads,
feature_map=config.feature_map,
use_short_conv=config.use_short_conv,
conv_size=config.conv_size,
share_conv_kernel=config.share_conv_kernel,
use_output_gate=config.use_output_gate,
gate_fn=config.hidden_act,
elementwise_affine=config.elementwise_affine,
norm_eps=config.norm_eps,
clamp_min=config.clamp_min,
fuse_norm=config.fuse_norm,
layer_idx=layer_idx
)
self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.mlp = GLAMLP(
hidden_size=config.hidden_size,
hidden_ratio=config.hidden_ratio,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.attn_norm(hidden_states)
hidden_states, attentions, past_key_values = self.attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions
)
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, attentions, past_key_values)
return outputs
class GLAPreTrainedModel(PreTrainedModel):
config_class = GLAConfig
supports_gradient_checkpointing = True
_no_split_modules = ['GLABlock']
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(
self,
module: nn.Module,
rescale_prenorm_residual: bool = True,
num_residuals_per_layer: int = 2,
):
if isinstance(module, (nn.Linear, nn.Conv1d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["o_proj.weight", "down_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
with torch.no_grad():
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
class GLAModel(GLAPreTrainedModel):
def __init__(self, config: GLAConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([GLABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, # noqa
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None
) -> Union[Tuple, BaseModelOutputWithPast]:
if output_attentions:
warnings.warn("`GLAModel` does not `output_attentions` now, setting it to `False`.")
output_attentions = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
hidden_states = inputs_embeds
if use_cache:
if past_key_values is None:
past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_attns = () if output_attentions else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
past_key_values,
use_cache,
output_attentions
)
else:
hidden_states, attentions, past_key_values = layer(
hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions
)
if output_attentions:
all_attns += (attentions,)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict:
return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_attns
)
class GLAForCausalLM(GLAPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = GLAModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embeddings
def set_input_embeddings(self, value):
self.model.embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def generate(self, *args, **kwargs):
try:
return super().generate(*args, **kwargs)
except AttributeError as exception:
if 'past_key_values' in str(exception):
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
else:
raise exception
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs
):
# only last token for `inputs_ids` if the `past_key_values` is passed along.
if past_key_values is not None:
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard.
# Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {'input_ids': input_ids.contiguous()}
model_inputs.update({
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if self.config.fuse_cross_entropy:
loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
else:
loss_fct = nn.CrossEntropyLoss()
# Enable model parallelism
labels = labels.to(logits.device)
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.hgrn.configuration_hgrn import HGRNConfig
from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel
AutoConfig.register(HGRNConfig.model_type, HGRNConfig)
AutoModel.register(HGRNConfig, HGRNModel)
AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM)
__all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel']

View File

@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
class HGRNConfig(PretrainedConfig):
model_type = 'hgrn'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
attn_mode: str = "chunk",
vocab_size: int = 32000,
hidden_size: int = 2048,
num_hidden_layers: int = 24,
num_heads: Optional[int] = 1,
expand_ratio: Optional[int] = 1,
use_short_conv: bool = False,
conv_size: int = 4,
share_conv_kernel: bool = True,
use_lower_bound: bool = True,
hidden_ratio: Optional[int] = 4,
intermediate_size: Optional[int] = None,
hidden_act: str = "swish",
max_position_embeddings: int = 2048,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-6,
use_cache: bool = True,
pad_token_id: int = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = False,
initializer_range: float = 0.02,
fuse_cross_entropy: bool = True,
**kwargs
):
self.attn_mode = attn_mode
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_heads = num_heads
self.expand_ratio = expand_ratio
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.share_conv_kernel = share_conv_kernel
self.use_lower_bound = use_lower_bound
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.elementwise_affine = elementwise_affine
self.norm_eps = norm_eps
self.use_cache = use_cache
self.initializer_range = initializer_range
self.fuse_cross_entropy = fuse_cross_entropy
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,407 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from fla.layers.hgrn import HGRNAttention
from fla.models.hgrn.configuration_hgrn import HGRNConfig
from fla.models.utils import RecurrentCache
from fla.modules import FusedCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear
logger = logging.get_logger(__name__)
class HGRNMLP(nn.Module):
def __init__(
self,
hidden_size: int,
hidden_ratio: Optional[int] = None,
intermediate_size: Optional[int] = None,
hidden_act: str = 'swish'
) -> HGRNMLP:
super().__init__()
self.hidden_size = hidden_size
# the final number of params is `hidden_ratio * hidden_size^2`
# `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
if hidden_ratio is None:
hidden_ratio = 4
if intermediate_size is None:
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
y = self.gate_proj(x)
gate, y = y.chunk(2, -1)
return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
class HGRNBlock(nn.Module):
def __init__(self, config: HGRNConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.attn = HGRNAttention(
mode=config.attn_mode,
hidden_size=config.hidden_size,
num_heads=config.num_heads,
expand_ratio=config.expand_ratio,
use_short_conv=config.use_short_conv,
conv_size=config.conv_size,
share_conv_kernel=config.share_conv_kernel,
elementwise_affine=config.elementwise_affine,
norm_eps=config.norm_eps,
layer_idx=layer_idx
)
self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.mlp = HGRNMLP(
hidden_size=config.hidden_size,
hidden_ratio=config.hidden_ratio,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
lower_bound: Optional[torch.Tensor] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.attn_norm(hidden_states)
hidden_states, attentions, past_key_values = self.attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
lower_bound=lower_bound
)
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, attentions, past_key_values)
return outputs
class HGRNPreTrainedModel(PreTrainedModel):
config_class = HGRNConfig
supports_gradient_checkpointing = True
_no_split_modules = ['HGRNBlock']
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(
self,
module: nn.Module,
rescale_prenorm_residual: bool = True,
num_residuals_per_layer: int = 2,
):
if isinstance(module, (nn.Linear, nn.Conv1d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["o_proj.weight", "down_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
with torch.no_grad():
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
class HGRNModel(HGRNPreTrainedModel):
def __init__(self, config: HGRNConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
if config.use_lower_bound:
self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size))
self.layers = nn.ModuleList([HGRNBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, # noqa
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None
) -> Union[Tuple, BaseModelOutputWithPast]:
if output_attentions:
warnings.warn("`HGRNModel` does not `output_attentions` now, setting it to `False`.")
output_attentions = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
hidden_states = inputs_embeds
if use_cache:
if past_key_values is None:
past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_attns = () if output_attentions else None
if self.config.use_lower_bound:
lower_bounds = self.lower_bounds.softmax(0)
lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
lower_bound = lower_bounds[i] if self.config.use_lower_bound else None
if self.gradient_checkpointing and self.training:
hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
past_key_values,
use_cache,
output_attentions,
lower_bound
)
else:
hidden_states, attentions, past_key_values = layer(
hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
lower_bound=lower_bound
)
if output_attentions:
all_attns += (attentions,)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict:
return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_attns
)
class HGRNForCausalLM(HGRNPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = HGRNModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embeddings
def set_input_embeddings(self, value):
self.model.embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def generate(self, *args, **kwargs):
try:
return super().generate(*args, **kwargs)
except AttributeError as exception:
if 'past_key_values' in str(exception):
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
else:
raise exception
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs
):
# only last token for `inputs_ids` if the `past_key_values` is passed along.
if past_key_values is not None:
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard.
# Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {'input_ids': input_ids.contiguous()}
model_inputs.update({
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if self.config.fuse_cross_entropy:
loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
else:
loss_fct = nn.CrossEntropyLoss()
# Enable model parallelism
labels = labels.to(logits.device)
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model
AutoConfig.register(HGRN2Config.model_type, HGRN2Config)
AutoModel.register(HGRN2Config, HGRN2Model)
AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM)
__all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model']

View File

@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
class HGRN2Config(PretrainedConfig):
model_type = 'hgrn2'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 2048,
num_hidden_layers: int = 24,
attn_mode: str = "chunk",
num_heads: Optional[int] = None,
expand_ratio: Optional[int] = 128,
use_short_conv: bool = False,
conv_size: int = 4,
share_conv_kernel: bool = True,
use_lower_bound: bool = True,
hidden_ratio: Optional[int] = 4,
intermediate_size: Optional[int] = None,
hidden_act: str = "swish",
max_position_embeddings: int = 2048,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-6,
use_cache: bool = True,
pad_token_id: int = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = False,
initializer_range: float = 0.02,
fuse_cross_entropy: bool = True,
**kwargs
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.attn_mode = attn_mode
self.num_heads = num_heads
self.expand_ratio = expand_ratio
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.share_conv_kernel = share_conv_kernel
self.use_lower_bound = use_lower_bound
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.elementwise_affine = elementwise_affine
self.norm_eps = norm_eps
self.use_cache = use_cache
self.initializer_range = initializer_range
self.fuse_cross_entropy = fuse_cross_entropy
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,407 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from fla.layers.hgrn2 import HGRN2Attention
from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
from fla.models.utils import RecurrentCache
from fla.modules import FusedCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear
logger = logging.get_logger(__name__)
class HGRN2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
hidden_ratio: Optional[int] = None,
intermediate_size: Optional[int] = None,
hidden_act: str = 'swish'
) -> HGRN2MLP:
super().__init__()
self.hidden_size = hidden_size
# the final number of params is `hidden_ratio * hidden_size^2`
# `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
if hidden_ratio is None:
hidden_ratio = 4
if intermediate_size is None:
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
y = self.gate_proj(x)
gate, y = y.chunk(2, -1)
return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
class HGRN2Block(nn.Module):
def __init__(self, config: HGRN2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.attn = HGRN2Attention(
mode=config.attn_mode,
hidden_size=config.hidden_size,
num_heads=config.num_heads,
expand_ratio=config.expand_ratio,
use_short_conv=config.use_short_conv,
conv_size=config.conv_size,
share_conv_kernel=config.share_conv_kernel,
elementwise_affine=config.elementwise_affine,
norm_eps=config.norm_eps,
layer_idx=layer_idx
)
self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.mlp = HGRN2MLP(
hidden_size=config.hidden_size,
hidden_ratio=config.hidden_ratio,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
lower_bound: Optional[torch.Tensor] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.attn_norm(hidden_states)
hidden_states, attentions, past_key_values = self.attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
lower_bound=lower_bound
)
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, attentions, past_key_values)
return outputs
class HGRN2PreTrainedModel(PreTrainedModel):
config_class = HGRN2Config
supports_gradient_checkpointing = True
_no_split_modules = ['HGRN2Block']
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(
self,
module: nn.Module,
rescale_prenorm_residual: bool = True,
num_residuals_per_layer: int = 2,
):
if isinstance(module, (nn.Linear, nn.Conv1d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["o_proj.weight", "down_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
with torch.no_grad():
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
class HGRN2Model(HGRN2PreTrainedModel):
def __init__(self, config: HGRN2Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
if config.use_lower_bound:
self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size))
self.layers = nn.ModuleList([HGRN2Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, # noqa
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None
) -> Union[Tuple, BaseModelOutputWithPast]:
if output_attentions:
warnings.warn("`HGRN2Model` does not `output_attentions` now, setting it to `False`.")
output_attentions = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
hidden_states = inputs_embeds
if use_cache:
if past_key_values is None:
past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_attns = () if output_attentions else None
if self.config.use_lower_bound:
lower_bounds = self.lower_bounds.softmax(0)
lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
lower_bound = lower_bounds[i] if self.config.use_lower_bound else None
if self.gradient_checkpointing and self.training:
hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
past_key_values,
use_cache,
output_attentions,
lower_bound
)
else:
hidden_states, attentions, past_key_values = layer(
hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
lower_bound=lower_bound
)
if output_attentions:
all_attns += (attentions,)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict:
return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_attns
)
class HGRN2ForCausalLM(HGRN2PreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = HGRN2Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embeddings
def set_input_embeddings(self, value):
self.model.embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def generate(self, *args, **kwargs):
try:
return super().generate(*args, **kwargs)
except AttributeError as exception:
if 'past_key_values' in str(exception):
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
else:
raise exception
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs
):
# only last token for `inputs_ids` if the `past_key_values` is passed along.
if past_key_values is not None:
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard.
# Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {'input_ids': input_ids.contiguous()}
model_inputs.update({
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if self.config.fuse_cross_entropy:
loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
else:
loss_fct = nn.CrossEntropyLoss()
# Enable model parallelism
labels = labels.to(logits.device)
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.linear_attn.configuration_linear_attn import \
LinearAttentionConfig
from fla.models.linear_attn.modeling_linear_attn import (
LinearAttentionForCausalLM, LinearAttentionModel)
AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig)
AutoModel.register(LinearAttentionConfig, LinearAttentionModel)
AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM)
__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel']

View File

@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
class LinearAttentionConfig(PretrainedConfig):
model_type = 'linear_attn'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 2048,
expand_k: int = 1,
expand_v: int = 1,
hidden_ratio: Optional[int] = 4,
intermediate_size: Optional[int] = None,
num_hidden_layers: int = 24,
num_heads: int = 4,
attn_mode: str = "fused_chunk",
feature_map: str = "elementwise_product",
tie_feature_map_qk: bool = False,
norm_q: bool = False,
norm_k: bool = False,
norm_feature_map: bool = False,
hidden_act: str = "swish",
max_position_embeddings: int = 2048,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-6,
use_cache: bool = True,
pad_token_id: int = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = False,
initializer_range: float = 0.02,
fuse_cross_entropy: bool = True,
**kwargs
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_heads = num_heads
self.attn_mode = attn_mode
self.feature_map = feature_map
self.tie_feature_map_qk = tie_feature_map_qk
self.norm_q = norm_q
self.norm_k = norm_k
self.norm_feature_map = norm_feature_map
self.hidden_act = hidden_act
self.elementwise_affine = elementwise_affine
self.norm_eps = norm_eps
self.use_cache = use_cache
self.initializer_range = initializer_range
self.fuse_cross_entropy = fuse_cross_entropy
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,424 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from fla.layers.linear_attn import LinearAttention
from fla.models.linear_attn.configuration_linear_attn import \
LinearAttentionConfig
from fla.modules import FusedCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear
logger = logging.get_logger(__name__)
class LinearAttentionMLP(nn.Module):
def __init__(
self,
hidden_size: int,
hidden_ratio: Optional[int] = None,
intermediate_size: Optional[int] = None,
hidden_act: str = 'swish'
) -> LinearAttentionMLP:
super().__init__()
self.hidden_size = hidden_size
# the final number of params is `hidden_ratio * hidden_size^2`
# `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
if hidden_ratio is None:
hidden_ratio = 4
if intermediate_size is None:
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
y = self.gate_proj(x)
gate, y = y.chunk(2, -1)
return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
class LinearAttentionBlock(nn.Module):
def __init__(self, config: LinearAttentionConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.attn = LinearAttention(
hidden_size=config.hidden_size,
expand_k=config.expand_k,
expand_v=config.expand_v,
num_heads=config.num_heads,
mode=config.attn_mode,
feature_map=config.feature_map,
tie_feature_map_qk=config.tie_feature_map_qk,
norm_q=config.norm_q,
norm_k=config.norm_k,
do_feature_map_norm=config.norm_feature_map,
elementwise_affine=config.elementwise_affine,
norm_eps=config.norm_eps,
layer_idx=layer_idx
)
self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.mlp = LinearAttentionMLP(
hidden_size=config.hidden_size,
hidden_ratio=config.hidden_ratio,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
# currently not supported
attn_weights, present_key_value = None, None
hidden_states = self.attn_norm(hidden_states)
hidden_states = self.attn(hidden_states)
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class LinearAttentionPreTrainedModel(PreTrainedModel):
config_class = LinearAttentionConfig
supports_gradient_checkpointing = True
_no_split_modules = ['LinearAttentionBlock']
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(
self,
module: nn.Module,
rescale_prenorm_residual: bool = True,
num_residuals_per_layer: int = 2,
):
if isinstance(module, (nn.Linear, nn.Conv1d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["o_proj.weight", "down_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
with torch.no_grad():
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
class LinearAttentionModel(LinearAttentionPreTrainedModel):
def __init__(self, config: LinearAttentionConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[LinearAttentionBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
if output_attentions:
warnings.warn(
"`LinearAttentionModel` does not support output attention weights now, "
"so `output_attentions` is set to `False`."
)
output_attentions = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
_, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
_, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
# embed positions
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class LinearAttentionForCausalLM(LinearAttentionPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = LinearAttentionModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embeddings
def set_input_embeddings(self, value):
self.model.embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def generate(self, *args, **kwargs):
try:
return super().generate(*args, **kwargs)
except AttributeError as exc:
# Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
if 'past_key_values' in str(exc):
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
else:
raise exc
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor = None,
state: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs
):
# only last token for inputs_ids if the state is passed along.
if state is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and state is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs["state"] = state
return model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if self.config.fuse_cross_entropy:
loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
else:
loss_fct = nn.CrossEntropyLoss()
# Enable model parallelism
labels = labels.to(logits.device)
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.mamba.configuration_mamba import MambaConfig
from fla.models.mamba.modeling_mamba import (MambaBlock, MambaForCausalLM,
MambaModel)
AutoConfig.register(MambaConfig.model_type, MambaConfig, True)
AutoModel.register(MambaConfig, MambaModel, True)
AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True)
__all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock']

View File

@ -0,0 +1,156 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MAMBA configuration"""
import math
from transformers.configuration_utils import PretrainedConfig
class MambaConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the MAMBA
[state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50280):
Vocabulary size of the MAMBA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MambaModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the embeddings and hidden states.
state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the model.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
The epsilon to use in the layer normalization layers.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 0):
The id of the beginning of sentence token in the vocabulary.
eos_token_id (`int`, *optional*, defaults to 0):
The id of the end of sentence token in the vocabulary.
expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
use_bias (`bool`, *optional*, defaults to `False`):
Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
use_conv_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias in the convolution layer of the mixer block.
hidden_act (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
initializer_range (`float`, *optional*, defaults to 0.1):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
residual_in_fp32 (`bool`, *optional*, defaults to `True`):
Whether or not residuals should be in `float32`.
If set to `False` residuals will keep the same `dtype` as the rest of the model
time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
Rank of the the discretization projection matrix.
`"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
time_step_scale (`float`, *optional*, defaults to 1.0):
Scale used used to scale `dt_proj.bias`.
time_step_min (`float`, *optional*, defaults to 0.001):
Minimum `time_step` used to bound `dt_proj.bias`.
time_step_max (`float`, *optional*, defaults to 0.1):
Maximum `time_step` used to bound `dt_proj.bias`.
time_step_init_scheme (`float`, *optional*, defaults to `"random"`):
Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]`
time_step_floor (`float`, *optional*, defaults to 0.0001):
Minimum clamping value of the `dt_proj.bias` layer initialization.
rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
Whether or not to rescale `out_proj` weights when initializing.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the cache should be used.
Example:
```python
>>> from transformers import MambaConfig, MambaModel
>>> # Initializing a Mamba configuration
>>> configuration = MambaConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = MambaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mamba"
def __init__(
self,
vocab_size=32000,
hidden_size=2048,
state_size=16,
num_hidden_layers=48,
layer_norm_epsilon=1e-5,
pad_token_id= 0,
bos_token_id= 1,
eos_token_id= 2,
expand=2,
conv_kernel=4,
use_bias=False,
use_conv_bias=True,
hidden_act="silu",
initializer_range=0.1,
residual_in_fp32=False,
time_step_rank="auto",
time_step_scale=1.0,
time_step_min=0.001,
time_step_max=0.1,
time_step_init_scheme="random",
time_step_floor=1e-4,
rescale_prenorm_residual=False,
use_cache=True,
fuse_norm: bool = True,
fuse_cross_entropy: bool = True,
tie_word_embeddings: bool = False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.state_size = state_size
self.num_hidden_layers = num_hidden_layers
self.layer_norm_epsilon = layer_norm_epsilon
self.conv_kernel = conv_kernel
self.expand = expand
self.intermediate_size = int(expand * self.hidden_size)
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.use_bias = use_bias
self.use_conv_bias = use_conv_bias
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
self.time_step_scale = time_step_scale
self.time_step_min = time_step_min
self.time_step_max = time_step_max
self.time_step_init_scheme = time_step_init_scheme
self.time_step_floor = time_step_floor
self.rescale_prenorm_residual = rescale_prenorm_residual
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
self.fuse_cross_entropy = fuse_cross_entropy
self.fuse_norm = fuse_norm
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)

View File

@ -0,0 +1,605 @@
# coding=utf-8
# Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MAMBA model."""
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging
from fla.models.mamba.configuration_mamba import MambaConfig
from fla.modules import FusedCrossEntropyLoss, RMSNorm
logger = logging.get_logger(__name__)
try:
from mamba_ssm.ops.selective_scan_interface import (mamba_inner_fn,
selective_scan_fn)
from mamba_ssm.ops.triton.selective_state_update import \
selective_state_update
except ImportError:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_update, causal_conv1d_fn = None, None
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
class MambaCache:
def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self.seqlen_offset = 0
self.dtype = dtype
intermediate_size = config.intermediate_size
ssm_state_size = config.state_size
conv_kernel_size = config.conv_kernel
self.conv_states = {
i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
for i in range(config.num_hidden_layers)
}
self.ssm_states = {
i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
for i in range(config.num_hidden_layers)
}
class MambaMixer(nn.Module):
"""
Compute , A, B, C, and D the state space parameters and compute the `contextualized_states`.
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)
"""
def __init__(self, config, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = config.time_step_rank
self.layer_idx = layer_idx
self.use_conv_bias = config.use_conv_bias
self.conv1d = nn.Conv1d(
in_channels=self.intermediate_size,
out_channels=self.intermediate_size,
bias=config.use_conv_bias,
kernel_size=config.conv_kernel,
groups=self.intermediate_size,
padding=config.conv_kernel - 1,
)
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
# selective projection used to make dt, B and C input dependant
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
# time step projection (discretization)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
A = A.expand(self.intermediate_size, -1).contiguous()
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.intermediate_size))
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias
if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because on of "
"`(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the naive implementation. "
"To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Optional[MambaCache] = None):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)
if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
contextualized_states = mamba_inner_fn(
projected_states,
self.conv1d.weight,
self.conv1d.bias if self.use_conv_bias else None,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias.float() if self.use_bias else None,
-torch.exp(self.A_log.float()),
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
else:
hidden_states, gate = projected_states.chunk(2, dim=1)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_params.seqlen_offset > 0:
hidden_states = causal_conv1d_update(
hidden_states.squeeze(-1),
cache_params.conv_states[self.layer_idx],
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.unsqueeze(-1)
else:
if cache_params is not None:
conv_states = nn.functional.pad(
hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
)
cache_params.conv_states[self.layer_idx].copy_(conv_states)
hidden_states = causal_conv1d_fn(
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
A = -torch.exp(self.A_log.float())
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
if cache_params is not None and cache_params.seqlen_offset > 0:
scan_outputs = selective_state_update(
cache_params.ssm_states[self.layer_idx],
hidden_states[..., 0],
discrete_time_step[..., 0],
A,
B[:, 0],
C[:, 0],
self.D,
gate[..., 0],
time_proj_bias,
dt_softplus=True,
).unsqueeze(-1)
else:
scan_outputs, ssm_state = selective_scan_fn(
hidden_states,
discrete_time_step,
A,
B.transpose(1, 2),
C.transpose(1, 2),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
return_last_state=True,
)
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
return contextualized_states
# fmt: off
def slow_forward(self, input_states, cache_params: Optional[MambaCache] = None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
# [batch, 2 * intermediate_size, seq_len]
projected_states = self.in_proj(input_states).transpose(1, 2)
hidden_states, gate = projected_states.chunk(2, dim=1)
# 2. Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
if cache_params.seqlen_offset > 0:
# [batch, intermediate_size, conv_kernel_size]
conv_state = cache_params.conv_states[self.layer_idx]
conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
conv_state[:, :, -1] = hidden_states[:, :, 0]
cache_params.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
if self.use_conv_bias:
hidden_states += self.conv1d.bias
# [batch, intermediate_size, 1] : decoding
hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)
else:
conv_state = nn.functional.pad(
hidden_states,
(self.conv_kernel_size - hidden_states.shape[-1], 0)
)
cache_params.conv_states[self.layer_idx].copy_(conv_state)
# [batch, intermediate_size, seq_len]
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
else:
ssm_state = torch.zeros(
(batch_size, self.intermediate_size, self.ssm_state_size),
device=hidden_states.device, dtype=dtype
)
# [batch, intermediate_size, seq_len]
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
# [batch, seq_len, intermediate_size]
discrete_time_step = self.dt_proj(time_step)
# [batch, intermediate_size, seq_len]
discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)
# 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
# [intermediate_size, ssm_state_size]
A = -torch.exp(self.A_log.float())
# [batch, intermediate_size, seq_len, ssm_state_size]
discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
# [batch, intermediade_size, seq_len, ssm_state_size]
discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
scan_outputs = []
for i in range(seq_len):
# [batch, intermediade_size, ssm_state]
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
# [batch, intermediade_size, 1]
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))
scan_outputs.append(scan_output[:, :, 0])
# [batch, seq_len, intermediade_size]
scan_output = torch.stack(scan_outputs, dim=-1)
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = (scan_output * self.act(gate))
if cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
# 4. Final linear projection
# [batch, seq_len, hidden_size]
contextualized_states = self.out_proj(scan_output.transpose(1, 2))
return contextualized_states
# fmt: on
def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params)
return self.slow_forward(hidden_states, cache_params)
class MambaBlock(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.residual_in_fp32 = config.residual_in_fp32
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mixer = MambaMixer(config, layer_idx=layer_idx)
def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
residual = hidden_states
hidden_states = self.norm(hidden_states)
# if self.residual_in_fp32:
# residual = residual.to(torch.float32)
hidden_states = self.mixer(hidden_states, cache_params=cache_params)
hidden_states = residual + hidden_states
return hidden_states
class MambaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MambaConfig
base_model_prefix = "backbone"
_no_split_modules = ["MambaBlock"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, MambaMixer):
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True
dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
if self.config.time_step_init_scheme == "constant":
nn.init.constant_(module.dt_proj.weight, dt_init_std)
elif self.config.time_step_init_scheme == "random":
nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
dt = torch.exp(
torch.rand(self.config.intermediate_size)
* (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+ math.log(self.config.time_step_min)
).clamp(min=self.config.time_step_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
module.dt_proj.bias.copy_(inv_dt)
module.dt_proj.bias._no_reinit = True
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=self.config.initializer_range)
if self.config.rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(self.config.num_layers)
@dataclass
class MambaOutput(ModelOutput):
"""
Class for the MAMBA model outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
cache_params (`MambaCache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
Includes both the State space model state matrices after the selective scan, and the Convolutional states
hidden_states (`tuple(torch.FloatTensor)`, *optional*,
returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: Optional[torch.FloatTensor] = None
cache_params: Optional[MambaCache] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MambaCausalLMOutput(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
cache_params (`MambaCache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
Includes both the State space model state matrices after the selective scan, and the Convolutional states
hidden_states (`tuple(torch.FloatTensor)`, *optional*,
returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
cache_params: Optional[MambaCache] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class MambaModel(MambaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, new_embeddings):
self.embeddings = new_embeddings
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
cache_params: Optional[MambaCache] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
) -> Union[Tuple, MambaOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
if self.gradient_checkpointing and self.training and use_cache:
use_cache = False
if cache_params is None and use_cache:
cache_params = MambaCache(
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
)
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
for mixer_block in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, cache_params)
else:
hidden_states = mixer_block(hidden_states, cache_params=cache_params)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if use_cache:
cache_params.seqlen_offset += inputs_embeds.shape[1]
hidden_states = self.norm_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
return MambaOutput(
last_hidden_state=hidden_states,
cache_params=cache_params if use_cache else None,
hidden_states=all_hidden_states,
)
class MambaForCausalLM(MambaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.backbone = MambaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_input_embeddings(self):
return self.backbone.get_input_embeddings()
def set_input_embeddings(self, new_embeddings):
return self.backbone.set_input_embeddings(new_embeddings)
def _update_model_kwargs_for_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
) -> Dict[str, Any]:
model_kwargs["cache_params"] = outputs.get("cache_params", None)
return model_kwargs
def prepare_inputs_for_generation(
self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs
):
# only last token for inputs_ids if the state is passed along.
if cache_params is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
if inputs_embeds is not None and cache_params is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs["cache_params"] = cache_params
return model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[MambaCache] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
**kwargs, # for now we need this for generation
) -> Union[Tuple, MambaCausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
mamba_outputs = self.backbone(
input_ids,
cache_params=cache_params,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
hidden_states = mamba_outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if self.config.fuse_cross_entropy:
loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
else:
loss_fct = nn.CrossEntropyLoss()
# Enable model parallelism
labels = labels.to(logits.device)
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + mamba_outputs[1:]
return (loss,) + output if loss is not None else output
return MambaCausalLMOutput(
loss=loss,
logits=logits,
cache_params=mamba_outputs.cache_params,
hidden_states=mamba_outputs.hidden_states,
)

View File

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.retnet.configuration_retnet import RetNetConfig
from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel
AutoConfig.register(RetNetConfig.model_type, RetNetConfig)
AutoModel.register(RetNetConfig, RetNetModel)
AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM)
__all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel']

View File

@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
class RetNetConfig(PretrainedConfig):
model_type = 'retnet'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 2048,
expand_k: int = 1,
expand_v: int = 2,
hidden_ratio: Optional[int] = 2,
intermediate_size: Optional[int] = None,
num_hidden_layers: int = 24,
num_heads: int = 8,
num_kv_heads: Optional[int] = None,
feature_map: Optional[str] = None,
attn_mode: str = "fused_chunk",
hidden_act: str = "swish",
use_short_conv: bool = False,
conv_size: int = 4,
share_conv_kernel: bool = True,
use_output_gate: bool = True,
max_position_embeddings: int = 2048,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-6,
use_cache: bool = True,
pad_token_id: int = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = False,
initializer_range: float = 0.02,
fuse_norm: bool = True,
fuse_cross_entropy: bool = True,
**kwargs
) -> RetNetConfig:
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.feature_map = feature_map
self.attn_mode = attn_mode
self.hidden_act = hidden_act
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.share_conv_kernel = share_conv_kernel
self.use_output_gate = use_output_gate
self.elementwise_affine = elementwise_affine
self.norm_eps = norm_eps
self.use_cache = use_cache
self.initializer_range = initializer_range
self.fuse_norm = fuse_norm
self.fuse_cross_entropy = fuse_cross_entropy
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from fla.layers.multiscale_retention import MultiScaleRetention
from fla.models.retnet.configuration_retnet import RetNetConfig
from fla.models.utils import RecurrentCache
from fla.modules import FusedCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear
logger = logging.get_logger(__name__)
class RetNetMLP(nn.Module):
def __init__(
self,
hidden_size: int,
hidden_ratio: Optional[int] = None,
intermediate_size: Optional[int] = None,
hidden_act: str = 'swish'
) -> RetNetMLP:
super().__init__()
self.hidden_size = hidden_size
# the final number of params is `hidden_ratio * hidden_size^2`
# `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
if hidden_ratio is None:
hidden_ratio = 4
if intermediate_size is None:
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
y = self.gate_proj(x)
gate, y = y.chunk(2, -1)
return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
class RetNetBlock(nn.Module):
def __init__(self, config: RetNetConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.attn = MultiScaleRetention(
mode=config.attn_mode,
hidden_size=config.hidden_size,
expand_k=config.expand_k,
expand_v=config.expand_v,
num_heads=config.num_heads,
num_kv_heads=config.num_kv_heads,
feature_map=config.feature_map,
use_output_gate=config.use_output_gate,
gate_fn=config.hidden_act,
elementwise_affine=config.elementwise_affine,
norm_eps=config.norm_eps,
fuse_norm=config.fuse_norm,
layer_idx=layer_idx
)
self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.mlp = RetNetMLP(
hidden_size=config.hidden_size,
hidden_ratio=config.hidden_ratio,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.attn_norm(hidden_states)
hidden_states, attentions, past_key_values = self.attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions
)
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, attentions, past_key_values)
return outputs
class RetNetPreTrainedModel(PreTrainedModel):
config_class = RetNetConfig
supports_gradient_checkpointing = True
_no_split_modules = ['RetNetBlock']
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(
self,
module: nn.Module,
rescale_prenorm_residual: bool = True,
num_residuals_per_layer: int = 2,
):
if isinstance(module, (nn.Linear, nn.Conv1d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["o_proj.weight", "down_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
with torch.no_grad():
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
class RetNetModel(RetNetPreTrainedModel):
def __init__(self, config: RetNetConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[RetNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, # noqa
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None
) -> Union[Tuple, BaseModelOutputWithPast]:
if output_attentions:
warnings.warn(
"`RetNetModel` does not support output attention weights now, so `output_attentions` is set to `False`."
)
output_attentions = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_len = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_len = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
hidden_states = inputs_embeds
if use_cache:
if past_key_values is None:
past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_attns = () if output_attentions else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
past_key_values,
use_cache,
output_attentions
)
else:
hidden_states, attentions, past_key_values = layer(
hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions
)
if output_attentions:
all_attns += (attentions,)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict:
return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_attns
)
class RetNetForCausalLM(RetNetPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = RetNetModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embeddings
def set_input_embeddings(self, value):
self.model.embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def generate(self, *args, **kwargs):
try:
return super().generate(*args, **kwargs)
except AttributeError as exception:
# Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
if 'past_key_values' in str(exception):
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
else:
raise exception
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor = None,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs
):
# only last token for `inputs_ids` if the `past_key_values` is passed along.
if past_key_values is not None:
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard.
# Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {'input_ids': input_ids.contiguous()}
model_inputs.update({
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if self.config.fuse_cross_entropy:
loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
else:
loss_fct = nn.CrossEntropyLoss()
# Enable model parallelism
labels = labels.to(logits.device)
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config
from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model
AutoConfig.register(RWKV6Config.model_type, RWKV6Config)
AutoModel.register(RWKV6Config, RWKV6Model)
AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM)
__all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model']

View File

@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
class RWKV6Config(PretrainedConfig):
model_type = 'rwkv6'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
attn_mode: str = "chunk",
vocab_size: int = 32000,
hidden_size: int = 2048,
expand_k: int = 0.5,
expand_v: int = 1,
hidden_ratio: Optional[int] = 3.5,
intermediate_size: Optional[int] = None,
use_glu: Optional[bool] = False,
num_hidden_layers: int = 24,
num_heads: int = 4,
proj_low_rank_dim: int = 32,
gate_low_rank_dim: int = 64,
hidden_act: str = "sqrelu",
max_position_embeddings: int = 2048,
eps: float = 1e-6,
use_cache: bool = True,
pad_token_id: int = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = False,
initializer_range: float = 0.02,
fuse_norm: bool = True,
fuse_cross_entropy: bool = True,
**kwargs
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.use_glu = use_glu
self.num_hidden_layers = num_hidden_layers
self.num_heads = num_heads
self.proj_low_rank_dim = proj_low_rank_dim
self.gate_low_rank_dim = gate_low_rank_dim
self.attn_mode = attn_mode
self.hidden_act = hidden_act
self.eps = eps
self.use_cache = use_cache
self.initializer_range = initializer_range
self.fuse_norm = fuse_norm
self.fuse_cross_entropy = fuse_cross_entropy
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,443 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from fla.layers.rwkv6 import LerpLinear, RWKV6Attention
from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config
from fla.models.utils import RecurrentCache
from fla.modules import FusedCrossEntropyLoss, LayerNorm
from fla.modules.activations import ACT2FN, swiglu_linear
logger = logging.get_logger(__name__)
class RWKV6FeedForward(nn.Module):
def __init__(
self,
hidden_size: int,
hidden_ratio: Optional[int] = None,
intermediate_size: Optional[int] = None,
hidden_act: str = 'sqrelu',
layer_idx: int = None
) -> RWKV6FeedForward:
super().__init__()
self.hidden_size = hidden_size
if hidden_ratio is None:
hidden_ratio = 3.5
if intermediate_size is None:
intermediate_size = int(hidden_size * hidden_ratio)
intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32)
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = LerpLinear(hidden_size, intermediate_size)
self.value = nn.Linear(intermediate_size, hidden_size)
self.receptance = LerpLinear(hidden_size, hidden_size)
self.act_fn = ACT2FN[hidden_act]
self.layer_idx = layer_idx
def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None) -> torch.Tensor:
if state is not None:
raise NotImplementedError("Past state is not yet supported in `RWKV6FeedForward`.")
shifted = self.time_shift(x)
if len(shifted.shape) == 2:
shifted = shifted.unsqueeze(1)
delta = shifted - x
key = self.act_fn(self.key(x, delta))
value = self.value(key)
receptance = self.receptance(x, delta)
return receptance.sigmoid() * value
class RWKV6GLU(nn.Module):
def __init__(
self,
hidden_size: int,
hidden_ratio: Optional[int] = None,
intermediate_size: Optional[int] = None,
hidden_act: str = 'swish',
layer_idx: int = None
) -> RWKV6GLU:
super().__init__()
self.hidden_size = hidden_size
# the final number of params is `hidden_ratio * hidden_size^2`
# `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
if hidden_ratio is None:
hidden_ratio = 4
if intermediate_size is None:
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.layer_idx = layer_idx
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
y = self.gate_proj(x)
gate, y = y.chunk(2, -1)
return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
class RWKV6Block(nn.Module):
def __init__(self, config: RWKV6Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.attn_norm = LayerNorm(hidden_size=config.hidden_size, eps=config.eps)
self.attn = RWKV6Attention(
mode=config.attn_mode,
hidden_size=config.hidden_size,
expand_k=config.expand_k,
expand_v=config.expand_v,
num_heads=config.num_heads,
proj_low_rank_dim=config.proj_low_rank_dim,
gate_low_rank_dim=config.gate_low_rank_dim,
eps=config.eps,
fuse_norm=config.fuse_norm,
layer_idx=layer_idx
)
self.ffn_norm = LayerNorm(hidden_size=config.hidden_size, eps=config.eps)
self.ffn = (RWKV6GLU if config.use_glu else RWKV6FeedForward)(
hidden_size=config.hidden_size,
hidden_ratio=config.hidden_ratio,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
layer_idx=layer_idx
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.attn_norm(hidden_states)
hidden_states, attentions, past_key_values = self.attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions
)
hidden_states, residual = self.ffn_norm(hidden_states, residual, True)
hidden_states = self.ffn(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, attentions, past_key_values)
return outputs
class RWKV6PreTrainedModel(PreTrainedModel):
config_class = RWKV6Config
supports_gradient_checkpointing = True
_no_split_modules = ['RWKV6Block']
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(
self,
module: nn.Module,
rescale_prenorm_residual: bool = True,
num_residuals_per_layer: int = 2,
):
if isinstance(module, (nn.Linear, nn.Conv1d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Parameter):
nn.init.normal_(module, mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["o_proj.weight", "down_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
with torch.no_grad():
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
class RWKV6Model(RWKV6PreTrainedModel):
def __init__(self, config: RWKV6Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([RWKV6Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
self.norm = LayerNorm(config.hidden_size, eps=config.eps)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, # noqa
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None
) -> Union[Tuple, BaseModelOutputWithPast]:
if output_attentions:
warnings.warn("`RWKV6Model` does not `output_attentions` now, setting it to `False`.")
output_attentions = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
hidden_states = inputs_embeds
if use_cache:
if past_key_values is None:
past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_attns = () if output_attentions else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
past_key_values,
use_cache,
output_attentions
)
else:
hidden_states, attentions, past_key_values = layer(
hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions
)
if output_attentions:
all_attns += (attentions,)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict:
return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_attns
)
class RWKV6ForCausalLM(RWKV6PreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = RWKV6Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embeddings
def set_input_embeddings(self, value):
self.model.embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def generate(self, *args, **kwargs):
try:
return super().generate(*args, **kwargs)
except AttributeError as exception:
if 'past_key_values' in str(exception):
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
else:
raise exception
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs
):
# only last token for `inputs_ids` if the `past_key_values` is passed along.
if past_key_values is not None:
if not isinstance(past_key_values, RecurrentCache):
past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard.
# Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {'input_ids': input_ids.contiguous()}
model_inputs.update({
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if self.config.fuse_cross_entropy:
loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
else:
loss_fct = nn.CrossEntropyLoss()
# Enable model parallelism
labels = labels.to(logits.device)
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.models.transformer.modeling_transformer import (
TransformerForCausalLM, TransformerModel)
AutoConfig.register(TransformerConfig.model_type, TransformerConfig)
AutoModel.register(TransformerConfig, TransformerModel)
AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)
__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel']

View File

@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
class TransformerConfig(PretrainedConfig):
model_type = 'transformer'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 2048,
hidden_ratio: Optional[int] = 4,
intermediate_size: Optional[int] = None,
num_hidden_layers: int = 24,
num_heads: int = 32,
num_kv_heads: int = None,
hidden_act: str = "swish",
max_position_embeddings: int = 2048,
initializer_range: float = 0.02,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-6,
use_cache: bool = True,
pad_token_id: int = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
tie_word_embeddings: bool = False,
attention_bias: bool = False,
fuse_norm: bool = True,
fuse_cross_entropy: bool = True,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.elementwise_affine = elementwise_affine
self.norm_eps = norm_eps
self.use_cache = use_cache
self.attention_bias = attention_bias
self.fuse_cross_entropy = fuse_cross_entropy
self.fuse_norm = fuse_norm
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,522 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.modules import FusedCrossEntropyLoss, RMSNorm, RotaryEmbedding
from fla.modules.activations import swiglu_linear
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import (index_first_axis, pad_input,
unpad_input)
except ImportError:
warnings.warn("Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`")
flash_attn_func = None
logger = logging.get_logger(__name__)
class TransformerAttention(nn.Module):
def __init__(
self,
config: TransformerConfig,
layer_idx: Optional[int] = None,
**kwargs
):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.num_heads = config.num_heads
if config.num_kv_heads is None:
self.num_kv_heads = self.num_heads
else:
self.num_kv_heads = config.num_kv_heads
self.num_kv_groups = config.num_heads // self.num_kv_heads
self.hidden_size = config.hidden_size
self.head_dim = self.hidden_size // self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.kv_dim = self.num_kv_heads * self.head_dim
self.max_position_embeddings = config.max_position_embeddings
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.rotary = RotaryEmbedding(self.head_dim)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
batch_size, q_len, _ = hidden_states.size()
q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', h=self.num_kv_heads)
v = rearrange(self.v_proj(hidden_states), 'b t (h d) -> b h t d', h=self.num_kv_heads)
seqlen_offset = 0
if past_key_values is not None:
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
if attention_mask is not None:
# to deliminate the offsets of padding tokens
seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
q, k = self.rotary(q, k, seqlen_offset, self.max_position_embeddings)
k = rearrange(k, 'b t h d -> b h t d')
if past_key_values is not None:
k, v = past_key_values.update(k, v, self.layer_idx)
k, v = rearrange(k, 'b h t d -> b t h d'), rearrange(v, 'b h t d -> b t h d')
if self.num_kv_groups > 1:
k = rearrange(k.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d')
v = rearrange(v.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d')
if flash_attn_func is None:
raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
# Contains at least one padding token in the sequence
if attention_mask is not None:
q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_q, max_seqlen_k = max_seq_lens
o = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
causal=True
)
o = pad_input(o, indices_q, batch_size, q_len)
else:
o = flash_attn_func(q, k, v, causal=True)
o = o.reshape(batch_size, q_len, self.hidden_size)
o = self.o_proj(o)
if not output_attentions:
attentions = None
return o, attentions, past_key_values
def _upad_input(self, q, k, v, attention_mask, q_len):
seqlens = attention_mask.sum(-1, dtype=torch.int32)
indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_k = seqlens.max().item()
cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
batch_size, seq_len, num_key_value_heads, head_dim = k.shape
k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
if q_len == seq_len:
q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_q = max_seqlen_k
indices_q = indices_k
elif q_len == 1:
max_seqlen_q = 1
# There is a memcpy here, that is very bad.
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
indices_q = cu_seqlens_q[:-1]
q = q.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -q_len:]
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
class TransformerMLP(nn.Module):
def __init__(
self,
hidden_size: int,
hidden_ratio: Optional[int] = None,
intermediate_size: Optional[int] = None,
hidden_act: str = 'swish'
) -> TransformerMLP:
super().__init__()
self.hidden_size = hidden_size
# the final number of params is `hidden_ratio * hidden_size^2`
# `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
if hidden_ratio is None:
hidden_ratio = 4
if intermediate_size is None:
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
self.hidden_ratio = hidden_ratio
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
y = self.gate_proj(x)
gate, y = y.chunk(2, -1)
return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
class TransformerBlock(nn.Module):
def __init__(self, config: TransformerConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.attn = TransformerAttention(
config=config,
layer_idx=layer_idx
)
self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
self.mlp = TransformerMLP(
hidden_size=config.hidden_size,
hidden_ratio=config.hidden_ratio,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.attn_norm(hidden_states)
hidden_states, attentions, past_key_values = self.attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions
)
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attentions,)
if use_cache:
outputs += (past_key_values,)
return outputs
class TransformerPreTrainedModel(PreTrainedModel):
config_class = TransformerConfig
supports_gradient_checkpointing = True
_no_split_modules = ['TransformerBlock']
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(
self,
module: nn.Module,
rescale_prenorm_residual: bool = True,
num_residuals_per_layer: int = 2,
):
if isinstance(module, (nn.Linear, nn.Conv1d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["o_proj.weight", "down_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
with torch.no_grad():
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
class TransformerModel(TransformerPreTrainedModel):
def __init__(self, config: TransformerConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([TransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
if output_attentions:
warnings.warn(
"`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`."
)
output_attentions = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
# embed positions
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_attns = () if output_attentions else None
next_decoder_cache = None
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
past_key_values,
output_attentions,
use_cache
)
else:
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_attns
)
class TransformerForCausalLM(TransformerPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = TransformerModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embeddings
def set_input_embeddings(self, value):
self.model.embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor = None,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs
):
# only last token for `inputs_ids` if the `past_key_values` is passed along.
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard.
# Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {'input_ids': input_ids.contiguous()}
model_inputs.update({
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
if self.config.fuse_cross_entropy:
loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
else:
loss_fct = nn.CrossEntropyLoss()
# Enable model parallelism
labels = labels.to(logits.device)
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

107
finetune/lora/v6/fla/models/utils.py vendored Normal file
View File

@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
import torch
from transformers.cache_utils import Cache
class RecurrentCache(Cache):
"""
A cache used for storing hidden states produced by flash linear attention models.
It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`.
"""
def __init__(
self,
seen_tokens: int = 0
) -> RecurrentCache:
self.states: List[torch.Tensor] = []
self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen
def __getitem__(self, layer_idx: int) -> torch.Tensor:
if layer_idx < len(self):
return self.states[layer_idx]
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
def __iter__(self):
for state in self.states:
yield state
def __len__(self):
return len(self.states)
def update(
self,
state: Tuple[torch.Tensor],
layer_idx: int,
offset: Optional[int] = 1,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
"""
Updates the cache with the new `state` for the layer `layer_idx`.
Parameters:
state (`Tuple[torch.Tensor]`):
The new state to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
offset (`int`):
The offset of current fed tokens.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass.
Return:
The updated state.
"""
if isinstance(state, torch.Tensor):
state = (state,)
if len(self.states) <= layer_idx:
self.states.append(state)
else:
for i, s in enumerate(state):
self.states[layer_idx][i].copy_(s)
# update the number of seen tokens once we achieve the last layer
if layer_idx == len(self) - 1:
self._seen_tokens += offset
return state
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.states) <= layer_idx:
return 0
return self._seen_tokens
def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. RecurrentCache does not have a maximum length."""
return None
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.states)):
device = self.states[layer_idx].device
self.states[layer_idx] = self.states[layer_idx].index_select(0, beam_idx.to(device))
def to_legacy_cache(self) -> Tuple[torch.Tensor]:
return tuple(self.states)
@classmethod
def from_legacy_cache(
cls,
past_key_values: Optional[Tuple[torch.Tensor]] = None,
seen_tokens: int = 0
) -> RecurrentCache:
"""Converts a cache in the legacy cache format into an equivalent `RecurrentCache`."""
cache = cls(seen_tokens)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
cache.update(past_key_values[layer_idx], layer_idx)
return cache

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
from fla.modules.convolution import (ImplicitLongConvolution, LongConvolution,
ShortConvolution)
from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss
from fla.modules.fused_norm_gate import (FusedLayerNormSwishGate,
FusedLayerNormSwishGateLinear,
FusedRMSNormSwishGate,
FusedRMSNormSwishGateLinear)
from fla.modules.layernorm import (LayerNorm, LayerNormLinear, RMSNorm,
RMSNormLinear)
from fla.modules.rotary import RotaryEmbedding
__all__ = [
'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
'FusedCrossEntropyLoss',
'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
'RotaryEmbedding'
]

View File

@ -0,0 +1,394 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2024, Tri Dao, Yu Zhang, Songlin Yang.
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from fla.utils import contiguous
sigmoid_fwd_codestring = """
template <typename T> T sigmoid_fwd(T x) {
return 1.0f / (1.0f + ::exp(-float(x)));
}
"""
sigmoid_bwd_codestring = """
template <typename T> T sigmoid_bwd(T x, T g) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
return float(g) * x_sigmoid * (1.0f - x_sigmoid);
}
"""
sigmoid_fwd = torch.cuda.jiterator._create_jit_fn(sigmoid_fwd_codestring)
sigmoid_bwd = torch.cuda.jiterator._create_jit_fn(sigmoid_bwd_codestring)
class SigmoidFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return sigmoid_fwd(x)
@staticmethod
def backward(ctx, dout):
x, = ctx.saved_tensors
return sigmoid_bwd(x, dout)
sigmoid = SigmoidFunction.apply
@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
triton.Config({'BT': 16}, num_warps=4),
triton.Config({'BT': 16}, num_warps=8),
triton.Config({'BT': 32}, num_warps=2),
triton.Config({'BT': 32}, num_warps=4),
triton.Config({'BT': 32}, num_warps=8),
triton.Config({'BT': 64}, num_warps=2),
triton.Config({'BT': 64}, num_warps=4),
triton.Config({'BT': 64}, num_warps=8),
triton.Config({'BT': 128}, num_warps=2),
triton.Config({'BT': 128}, num_warps=4),
triton.Config({'BT': 128}, num_warps=8),
triton.Config({'BT': 256}, num_warps=2),
triton.Config({'BT': 256}, num_warps=4),
triton.Config({'BT': 256}, num_warps=8)
],
key=['D']
)
@triton.jit
def logsigmoid_fwd_kernel(
x,
y,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr
):
i = tl.program_id(0)
o_i = i * BT + tl.arange(0, BT)
p_x = x + o_i
p_y = y + o_i
mask = o_i < T
# [D,]
b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)
b_m = tl.minimum(0., b_x)
b_z = 1. + tl.exp(-tl.abs(b_x))
b_y = b_m - tl.log(b_z)
tl.store(p_y, b_y.to(p_y.dtype.element_ty), mask=mask)
@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
triton.Config({'BT': 16}, num_warps=4),
triton.Config({'BT': 16}, num_warps=8),
triton.Config({'BT': 32}, num_warps=2),
triton.Config({'BT': 32}, num_warps=4),
triton.Config({'BT': 32}, num_warps=8),
triton.Config({'BT': 64}, num_warps=2),
triton.Config({'BT': 64}, num_warps=4),
triton.Config({'BT': 64}, num_warps=8),
triton.Config({'BT': 128}, num_warps=2),
triton.Config({'BT': 128}, num_warps=4),
triton.Config({'BT': 128}, num_warps=8),
triton.Config({'BT': 256}, num_warps=2),
triton.Config({'BT': 256}, num_warps=4),
triton.Config({'BT': 256}, num_warps=8)
],
key=['D']
)
@triton.jit
def logsigmoid_bwd_kernel(
x,
dx,
dy,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr
):
i = tl.program_id(0)
o_i = i * BT + tl.arange(0, BT)
p_x = x + o_i
p_dx = dx + o_i
p_dy = dy + o_i
mask = o_i < T
# [D,]
b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)
b_dy = tl.load(p_dy, mask=mask, other=0.).to(tl.float32)
b_dx = b_dy * (1. - tl.sigmoid(b_x))
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
class LogSigmoidFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, x):
T, D = x.numel(), x.shape[-1]
y = torch.empty_like(x)
logsigmoid_fwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, y, T=T, D=D)
ctx.save_for_backward(x,)
return y
@staticmethod
@contiguous
def backward(ctx, dy):
x, = ctx.saved_tensors
T, D = x.numel(), x.shape[-1]
dx = torch.empty_like(x)
logsigmoid_bwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, dx, dy, T=T, D=D)
return dx
logsigmoid = LogSigmoidFunction.apply
swish_fwd_codestring = """
template <typename T> T swish_fwd(T x) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
return float(x) * x_sigmoid;
}
"""
swish_bwd_codestring = """
template <typename T> T swish_bwd(T x, T g) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
return float(g) * x_sigmoid * (1.0f - float(x) * x_sigmoid + float(x));
}
"""
swish_fwd = torch.cuda.jiterator._create_jit_fn(swish_fwd_codestring)
swish_bwd = torch.cuda.jiterator._create_jit_fn(swish_bwd_codestring)
class SwishFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return swish_fwd(x)
@staticmethod
def backward(ctx, dout):
x, = ctx.saved_tensors
return swish_bwd(x, dout)
swish = SwishFunction.apply
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(y, bias):
x = bias + y
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_bwd(g, y, bias):
"""Assume that y has shape (B, D) and bias has shape (D)"""
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
grad_y = ff * g
return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(input, bias)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_bwd(grad_output, input, bias)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def gelu_fwd(x):
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def gelu_bwd(g, x):
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
return (ff * g).to(dtype=x.dtype)
class FastGeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input):
ctx.save_for_backward(input)
return gelu_fwd(input)
@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
tmp = gelu_bwd(grad_output, input)
return tmp
fast_gelu_impl = FastGeLUFunction.apply
@torch.jit.script
def relu_bwd(g, x):
return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
@torch.jit.script
def sqrelu_fwd(x):
r = F.relu(x)
return (r * r).to(dtype=x.dtype)
@torch.jit.script
def sqrelu_bwd(g, x):
return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
class SquaredReLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return sqrelu_fwd(input)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return sqrelu_bwd(grad_output, input)
sqrelu = SquaredReLUFunction.apply
swiglu_fwd_codestring = """
template <typename T> T swiglu_fwd(T x, T y) {
return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd_codestring = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
dy = float(x) * x_sigmoid * float(g);
}
"""
swiglu_bwd_with_output_codestring = """
template <typename T> T swiglu_bwd_with_output(T x, T y, T g, T& dx, T& dy, T& z) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
float x_swish = float(x) * x_sigmoid;
dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
dy = x_swish * float(g);
z = x_swish * float(y);
}
"""
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
swiglu_bwd_with_output = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_with_output_codestring, num_outputs=3)
class SwiGLUFunction(torch.autograd.Function):
r"""
Swish-Gated Linear Unit (SwiGLU) function.
.. math::
\text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y
"""
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return swiglu_fwd(x, y)
@staticmethod
def backward(ctx, dout):
x, y = ctx.saved_tensors
return swiglu_bwd(x, y, dout)
class SwiGLULinearFunction(torch.autograd.Function):
r"""
Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation.
.. math::
\text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b
This simple wrap discards the intermediate results of SwiGLU(x, y) to save memory.
"""
@staticmethod
def forward(ctx, x, y, weight, bias):
z = swiglu_fwd(x, y)
out = F.linear(z.to(weight.dtype), weight, bias)
# We don't store z, will be recomputed in the backward pass to save memory
ctx.save_for_backward(x, y, weight)
ctx.linear_bias_is_none = bias is None
return out
@staticmethod
def backward(ctx, dout, *args):
x, y, weight = ctx.saved_tensors
dout = dout.reshape(-1, dout.shape[-1])
dz = F.linear(dout, weight.t()).view_as(x)
dx, dy, z = swiglu_bwd_with_output(x, y, dz)
dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1]))
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
return dx, dy, dlinear_weight, dlinear_bias
swiglu = SwiGLUFunction.apply
swiglu_linear = SwiGLULinearFunction.apply
ACT2FN = {
'relu': F.relu,
'sigmoid': sigmoid,
'logsigmoid': logsigmoid,
'silu': swish,
'swish': swish,
'sqrelu': sqrelu,
'gelu': fast_gelu_impl,
'bias_gelu': bias_gelu_impl,
}

View File

@ -0,0 +1,336 @@
# -*- coding: utf-8 -*-
# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py
import math
import warnings
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from fla.modules.activations import ACT2FN
from fla.utils import checkpoint
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn = None
causal_conv1d_update = None
def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
seqlen = u.shape[-1]
fft_size = 2 * seqlen
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
if k_rev is not None:
k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
k_f = k_f + k_rev_f.conj()
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
if len(u.shape) > 3:
k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
out = y + u
if gelu:
out = F.gelu(out)
if dropout_mask is not None:
return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
else:
return out.to(dtype=u.dtype)
@checkpoint
def proj_then_conv1d(
x: torch.Tensor,
proj_weight: torch.Tensor,
conv1d_weight: torch.Tensor,
conv1d_bias: Optional[torch.Tensor] = None,
cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
# We do matmul and transpose BLH -> HBL at the same time
x = rearrange(proj_weight @ rearrange(x, "b l d -> d (b l)"), "d (b l) -> b d l", l=x.shape[-2])
if causal_conv1d_fn is None:
raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
if cache is None:
x = causal_conv1d_fn(
x=x,
weight=rearrange(conv1d_weight, "d 1 w -> d w"),
bias=conv1d_bias,
activation="silu",
).transpose(1, 2)
else:
assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
x = x.squeeze(-1)
x = causal_conv1d_update(
x=x,
weight=rearrange(conv1d_weight, "d 1 w -> d w"),
bias=conv1d_bias,
cache=cache,
activation="silu",
)
return x
class ShortConvolution(nn.Conv1d):
"""
Simple wrapper around `nn.Conv1d` that accepts dimension last.
"""
def __init__(
self,
hidden_size: int,
kernel_size: int,
bias: bool = False,
activation: Optional[str] = 'silu',
use_causal_conv: Optional[bool] = True
):
super().__init__(in_channels=hidden_size,
out_channels=hidden_size,
kernel_size=kernel_size,
groups=hidden_size,
bias=bias,
padding=kernel_size - 1)
self.hidden_size = hidden_size
self.activation = None
if activation is not None:
assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
self.activation = activation
if use_causal_conv:
if causal_conv1d_fn is None:
warnings.warn("Please install `causal-conv1d` to use causal convolutions, setting `use_causal_conv` to False.")
use_causal_conv = False
self.use_causal_conv = use_causal_conv
def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0,) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1,) * len(self.dilation):
s += ', dilation={dilation}'
if self.output_padding != (0,) * len(self.output_padding):
s += ', output_padding={output_padding}'
if self.groups != 1:
s += ', groups={groups}'
if self.bias is None:
s += ', bias=False'
if self.padding_mode != 'zeros':
s += ', padding_mode={padding_mode}'
if self.activation is not None:
s += ', activation={activation}'
if not self.use_causal_conv:
s += ', use_causal_conv={use_causal_conv}'
return s.format(**self.__dict__)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
x (`torch.Tensor`):
Tensor of shape `[batch_size, seq_len, hidden_size]`
mask (`Optional[torch.Tensor]`):
Attention mask dealing with padded positions.
cache (`Optional[torch.Tensor]`):
Previous cache tensor of shape `[batch_size, hidden_size, kernel_size]`,
Returns:
Tensor of shape `[batch_size, seq_len, hidden_size]`. The `cache` (if provided) is updated inplace.
"""
if mask is not None:
x = x.mul_(mask.unsqueeze(-1))
if cache is not None and x.shape[1] == 1:
return self.step(x, cache)
x = rearrange(x, "b l d -> b d l")
# Update state (B D W)
if cache is not None:
cache.copy_(F.pad(x, (self.kernel_size[0] - x.shape[-1], 0)))
if self.use_causal_conv:
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.weight, "d 1 w -> d w"),
bias=self.bias,
activation=self.activation,
)
else:
x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
if self.activation is not None:
x = ACT2FN[self.activation](x)
return rearrange(x, "b d l -> b l d")
def step(
self,
x: torch.Tensor,
cache: torch.Tensor
):
assert x.shape[1] == 1, "Only support decoding with 1 token at a time for now"
x = x.squeeze(1)
if self.use_causal_conv:
x = causal_conv1d_update(
x=x,
conv_state=cache,
weight=rearrange(self.weight, "d 1 w -> d w"),
bias=self.bias,
activation=self.activation,
)
else:
dtype = x.dtype
cache.copy_(torch.roll(cache, shifts=-1, dims=-1))
cache[:, :, -1] = x
x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
if self.bias is not None:
x = x + self.bias
if self.activation is not None:
x = ACT2FN[self.activation](x).to(dtype=dtype)
return x.unsqueeze(1)
@property
def state_size(self) -> int:
return self.hidden_size * self.kernel_size
class LongConvolution(nn.Module):
"""
LongConvolution applies a convolution operation on the input tensor using a fixed
filter of length l_max.
The filter is learned during training and is applied using FFT convolution.
Args:
hidden_size (int): The number of expected features in the input and output.
l_max (int): The maximum sequence length.
Returns:
y: (b, l, d) tensor
"""
def __init__(
self,
hidden_size: int,
l_max: int,
**kwargs,
):
"""
Initializes the LongConvolution module.
Args:
hidden_size (int): The number of expected features in the input and output.
l_max (int): The maximum sequence length.
"""
super().__init__()
self.hidden_size = hidden_size
self.filter = nn.Parameter(torch.randn(self.hidden_size, l_max), requires_grad=True)
def forward(self, x: torch.Tensor, *args, **kwargs):
"""
Applies the LongConvolution operation on the input tensor.
Args:
x: (b, l, d) tensor
Returns:
y: (b, l, d) tensor
"""
x = x.transpose(1, 2)
y = fft_conv(x, self.filter, dropout_mask=None, gelu=False)
y = y.transpose(1, 2)
return y.to(dtype=x.dtype)
class PositionalEmbedding(nn.Module):
def __init__(self, emb_dim: int, seq_len: int, **kwargs):
"""Complex exponential positional embeddings for implicit long convolution filters."""
super().__init__()
self.seq_len = seq_len
# The time embedding fed to the filteres is normalized so that t_f = 1
t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
if emb_dim > 1:
bands = (emb_dim - 1) // 2
# To compute the right embeddings we use the "proper" linspace
t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1
f = torch.linspace(1e-4, bands - 1, bands)[None, None]
z = torch.exp(-1j * f * w)
z = torch.cat([t, z.real, z.imag], dim=-1)
self.z = nn.Parameter(z, requires_grad=False)
def forward(self, L):
return self.z[:, :L]
class ImplicitLongConvolution(nn.Module):
"""
Long convolution with implicit filter parameterized by an MLP.
Args:
hidden_size (int):
The number of expected features in the input and output.
l_max (int):
The maximum sequence length.
d_emb (Optional[int]):
The dimension of the positional embeddings. Must be odd and greater or equal to 3 (time, sine and cosine).
Defaults to 3.
d_hidden (Optional[int]):
The number of features in the hidden layer of the MLP. Defaults to 16.
Attributes:
pos_emb (`PositionalEmbedding`): The positional embedding layer.
mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter.
"""
def __init__(
self,
hidden_size: int,
l_max: int,
d_emb: int = 3,
d_hidden: int = 16,
**kwargs,
):
"""
Long convolution with implicit filter parameterized by an MLP.
"""
super().__init__()
self.hidden_size = hidden_size
self.d_emb = d_emb
assert (
d_emb % 2 != 0 and d_emb >= 3
), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)"
self.pos_emb = PositionalEmbedding(d_emb, l_max)
# final linear layer
self.mlp = nn.Sequential(
nn.Linear(d_emb, d_hidden),
torch.nn.ReLU(),
nn.Linear(d_hidden, hidden_size),
)
def filter(self, seq_len: int, *args, **kwargs):
k = self.mlp(self.pos_emb(seq_len))
return k.transpose(1, 2)
def forward(self, x: torch.Tensor, *args, **kwargs):
"""
Args:
x: (b, l, d) tensor
Returns:
y: (b, l, d) tensor
"""
x = x.transpose(1, 2)
k = self.filter(x.shape[-1])
y = fft_conv(x, k, dropout_mask=None, gelu=False)
y = y.transpose(1, 2)
return y.to(dtype=x.dtype)

View File

@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from fla.modules.layernorm import layer_norm_fn
from fla.utils import checkpoint
@checkpoint
def flatten_diag_outer_product(x, y):
z = torch.einsum("...i,...j->...ij", x, y)
N = z.size(-1)
indicies = torch.triu_indices(N, N)
return z[..., indicies[0], indicies[1]]
@checkpoint
def flatten_diag_outer_product_off1(x, y):
z = torch.einsum("...i,...j->...ij", x, y)
N = z.size(-1)
indicies = torch.triu_indices(N, N, 1)
indices2 = torch.arange(0, N)
return z[..., indicies[0], indicies[1]], z[..., indices2, indices2]
def is_power_of_2(n):
return (n & (n - 1) == 0) and n != 0
class HedgehogFeatureMap(nn.Module):
r"""
Hedgehog feature map as introduced in
`The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry <https://arxiv.org/abs/2402.04347>`_
"""
def __init__(
self,
head_dim: int
) -> HedgehogFeatureMap:
super().__init__()
# Trainable map
self.layer = nn.Linear(head_dim, head_dim)
self.init_weights_()
def init_weights_(self):
"""Initialize trainable map as identity"""
with torch.no_grad():
identity = torch.eye(*self.layer.weight.shape[-2:], dtype=torch.float)
self.layer.weight.copy_(identity.to(self.layer.weight))
nn.init.zeros_(self.layer.bias)
def forward(self, x: torch.Tensor):
x = self.layer(x) # shape b, h, l, d
return torch.cat([2*x, -2*x], dim=-1).softmax(-1)
class T2RFeatureMap(nn.Module):
r"""
Simple linear mapping feature map as in
`Finetuning Pretrained Transformers into RNNs <https://arxiv.org/abs/2103.13076>`_
"""
def __init__(
self,
head_dim: int,
dot_dim: int = None
) -> T2RFeatureMap:
super().__init__()
# Trainable map
if dot_dim is None:
dot_dim = head_dim
self.layer = nn.Linear(head_dim, dot_dim)
def forward(self, x: torch.Tensor):
return self.layer(x).relu()
class DPFPFeatureMap(nn.Module):
r"""
Deterministic Parameter-Free Projection (DPFP) feature map in
`Linear Transformers Are Secretly Fast Weight Programmers <https://arxiv.org/abs/2102.11174>`_
"""
def __init__(
self,
head_dim: int,
nu: int = 4
) -> DPFPFeatureMap:
super().__init__()
self.nu = nu
def forward(self, x: torch.Tensor):
x = torch.cat([x.relu(), -x.relu()], dim=-1)
x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu+1)], dim=-1)
x_repeat = torch.cat([x] * self.nu, dim=-1)
return x_repeat * x_rolled
class HadamardFeatureMap(nn.Module):
def __init__(
self,
head_dim: int
) -> HadamardFeatureMap:
super().__init__()
# Trainable map
self.layer1 = nn.Linear(head_dim, head_dim)
self.layer2 = nn.Linear(head_dim, head_dim)
def forward(self, x: torch.Tensor):
return self.layer1(x) * self.layer2(x)
class LearnableOuterProductFeatureMap(nn.Module):
def __init__(
self,
head_dim: int,
feature_dim: int
) -> LearnableOuterProductFeatureMap:
super().__init__()
# Trainable map
self.layer1 = nn.Linear(head_dim, feature_dim, bias=False)
self.layer2 = nn.Linear(head_dim, feature_dim, bias=False)
self.normalizer = feature_dim ** -0.5
def forward(self, x: torch.Tensor):
return flatten_diag_outer_product(self.layer1(x), self.layer2(x))
class LearnablePolySketchNonNegativeFeatureMap(nn.Module):
def __init__(
self,
head_dim: int,
sketch_size: Optional[int] = None,
degree: Optional[int] = 2
) -> LearnablePolySketchNonNegativeFeatureMap:
super().__init__()
assert is_power_of_2(degree) and degree >= 2, f"The degree {degree} must be a power of 2"
self.head_dim = head_dim
self.sketch_size = sketch_size if sketch_size is not None else head_dim
self.degree = degree
self.gamma = nn.Parameter(torch.ones(head_dim))
self.beta = nn.Parameter(torch.zeros(head_dim))
# NOTE: the sketch layers defined here are quite different from the original paper
# currently we simply use linear layers without any non-linear activations
self.sketches1 = nn.ModuleList([
nn.Linear(head_dim, sketch_size, bias=False),
*[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
])
self.sketches2 = nn.ModuleList([
nn.Linear(head_dim, sketch_size, bias=False),
*[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
])
def forward(self, x: torch.Tensor):
# Section 2.1
x = layer_norm_fn(x, self.gamma, self.beta)
# first map the input to sketch size with learnable parameters
x = self.sketches1[0](x) * self.sketches2[0](x) * self.head_dim ** -0.5
for i in range(1, int(math.log2(self.degree)) - 1):
x = self.sketches1[i](x) * self.sketches2[i](x) * self.head_dim ** -0.5
# do sketch mapping for log2(p) - 1 times in total
# do p=2 mapping to ensure non-negativity
return flatten_diag_outer_product(x, x)
class TaylorFeatureMap(nn.Module):
def __init__(
self,
head_dim: int
) -> TaylorFeatureMap:
super().__init__()
self.head_dim = head_dim
self.r2 = math.sqrt(2)
self.rd = math.sqrt(self.head_dim)
self.rrd = math.sqrt(self.rd)
def forward(self, x: torch.Tensor):
x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
return torch.cat([torch.ones_like(x[..., 0:1]), x / self.rrd, x2_2 / (self.rd * self.r2), x2_1 / self.rd], dim=-1)
class RebasedFeatureMap(nn.Module):
def __init__(
self,
head_dim: int,
use_gamma: Optional[bool] = True,
use_beta: Optional[bool] = True,
normalize: Optional[bool] = True
) -> RebasedFeatureMap:
super().__init__()
self.head_dim = head_dim
self.use_gamma = use_gamma
self.use_beta = use_beta
self.normalize = normalize
self.gamma = None
self.beta = None
if use_gamma:
self.gamma = nn.Parameter(torch.ones(head_dim))
if use_beta:
self.beta = nn.Parameter(torch.zeros(head_dim))
def forward(self, x: torch.Tensor, flatten: Optional[bool] = True):
if self.use_beta and self.use_gamma and self.normalize:
x = layer_norm_fn(x, self.gamma, self.beta)
elif self.normalize:
x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta)
elif self.use_gamma and self.use_beta:
x = torch.addcmul(self.beta, x, self.gamma)
elif self.use_gamma:
x = x.mul(self.gamma)
else:
raise RuntimeError(f"Not supported combination of `use_gamma`, `use_beta` and `normalize`, "
f"which is currentlt set as (`{self.use_gamma}`, `{self.use_beta}`, `{self.normalize}`)")
if not flatten:
return x
x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
# rebased use learnable parameters to approximate any quadratic function
return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1)

View File

@ -0,0 +1,398 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Tri Dao.
from typing import Tuple
import torch
import torch.nn as nn
import triton
import triton.language as tl
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
@triton.heuristics(
{
"HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
}
)
@triton.jit
def cross_entropy_fwd_kernel(
loss_ptr, # data ptrs
lse_ptr,
z_loss_ptr,
logits_ptr,
labels_ptr,
smoothing,
logit_scale,
lse_square_scale,
ignored_index,
total_classes,
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
n_cols, # shapes
n_rows,
logits_row_stride, # strides
BLOCK_SIZE: tl.constexpr,
HAS_SMOOTHING: tl.constexpr,
# if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
SPLIT: tl.constexpr,
):
row_idx = tl.program_id(0)
col_block_idx = tl.program_id(1)
logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
label_idx = tl.load(labels_ptr + row_idx)
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
tl.float32
) * logit_scale
max_logits = tl.max(logits, 0)
if HAS_SMOOTHING:
sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
if label_idx == ignored_index:
loss = 0.0
z_loss = 0.0
else:
label_idx -= class_start_idx
if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
n_cols, (col_block_idx + 1) * BLOCK_SIZE
):
logits_label = tl.load(logits_ptr + label_idx) * logit_scale
if HAS_SMOOTHING:
loss = (
(lse if not SPLIT else 0.0)
- smoothing * sum_logits / total_classes
- (1 - smoothing) * logits_label
)
else:
loss = (lse if not SPLIT else 0.0) - logits_label
else:
# If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
if HAS_SMOOTHING:
loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
else:
loss = 0.0
if not SPLIT:
z_loss = lse_square_scale * lse * lse
loss += z_loss
else:
z_loss = 0.0
tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
if not SPLIT:
tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
@triton.heuristics(
{
"HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
}
)
@triton.jit
def cross_entropy_bwd_kernel(
dlogits_ptr, # data ptrs
dloss_ptr,
logits_ptr,
lse_ptr,
labels_ptr,
smoothing,
logit_scale,
lse_square_scale,
ignored_index,
total_classes,
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
n_cols, # shapes
logits_row_stride, # strides
dlogits_row_stride,
dloss_row_stride,
BLOCK_SIZE: tl.constexpr,
HAS_SMOOTHING: tl.constexpr,
):
row_idx = tl.program_id(0)
col_block_idx = tl.program_id(1)
logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
label_idx = tl.load(labels_ptr + row_idx)
if label_idx != ignored_index:
dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
else:
dloss = 0.0
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
tl.float32
) * logit_scale
lse = tl.load(lse_ptr + row_idx)
probs = tl.exp(logits - lse)
probs += 2.0 * lse_square_scale * lse * probs
label_idx -= class_start_idx
if HAS_SMOOTHING:
smooth_negative = smoothing / total_classes
probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
else:
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
class CrossEntropyLossFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
logits,
labels,
smoothing=0.0,
logit_scale=1.0,
lse_square_scale=0.0,
ignored_index=-100,
inplace_backward=False,
process_group=None,
):
n_rows, n_cols = logits.shape
assert labels.shape == (n_rows,)
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
total_classes = world_size * n_cols
rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
class_start_idx = rank * n_cols
if logits.stride(-1) != 1:
logits = logits.contiguous()
# Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
MAX_BLOCK_SIZE = 64 * 1024
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
num_warps = (
4
if BLOCK_SIZE < 2048
else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
)
# We may split the lse computation across multiple blocks, then do a reduction
# lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
# where having just one thread block processing more than 64k elements is slow.
split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(logits.device.index):
cross_entropy_fwd_kernel[(n_rows, n_splits)](
losses, # data ptrs
lse,
z_losses,
logits,
labels,
smoothing,
logit_scale,
lse_square_scale,
ignored_index,
total_classes,
class_start_idx,
n_cols, # shapes
n_rows,
logits.stride(0), # strides
BLOCK_SIZE=BLOCK_SIZE, # constants
num_warps=num_warps,
SPLIT=split,
)
if split:
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
# - predicted logit, and 0 otherwise.
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
# -0.9 * predicted logit - 0.1 * sum logit / total_classes.
# For labels not in the vocab of this partition, losses contains
# -0.1 * sum logit / total_classes.
if n_splits > 1:
lse = torch.logsumexp(lse, dim=0)
losses = losses.sum(dim=0)
if world_size > 1:
lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
handle_losses = torch.distributed.all_reduce(
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
)
lse = torch.logsumexp(lse_allgather, dim=0)
handle_losses.wait()
# After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
# we just have to add the (global) lse.
# If there's smoothing=0.1, the total losses are
# -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
# Again, we just have to add the (global) lse.
losses += lse
if lse_square_scale != 0.0:
z_losses = lse_square_scale * lse.square()
z_losses.masked_fill_(labels == ignored_index, 0.0)
losses += z_losses
else:
z_losses = torch.zeros_like(losses)
losses.masked_fill_(labels == ignored_index, 0.0)
ctx.save_for_backward(logits, lse, labels)
ctx.mark_non_differentiable(z_losses)
ctx.smoothing = smoothing
ctx.logit_scale = logit_scale
ctx.lse_square_scale = lse_square_scale
ctx.ignored_index = ignored_index
ctx.total_classes = total_classes
ctx.class_start_idx = class_start_idx
ctx.inplace_backward = inplace_backward
return losses, z_losses
@staticmethod
def backward(ctx, grad_losses, grad_z_losses):
del grad_z_losses # z_losses are only for logging.
logits, lse, labels = ctx.saved_tensors
dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
n_rows, n_cols = logits.shape
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
def grid(META): return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(logits.device.index):
cross_entropy_bwd_kernel[grid](
dlogits, # data ptrs
grad_losses,
logits,
lse,
labels,
ctx.smoothing,
ctx.logit_scale,
ctx.lse_square_scale,
ctx.ignored_index,
ctx.total_classes,
ctx.class_start_idx,
n_cols, # shapes
logits.stride(0), # strides
dlogits.stride(0),
grad_losses.stride(0),
BLOCK_SIZE=BLOCK_SIZE, # constants
num_warps=num_warps,
)
return dlogits, None, None, None, None, None, None, None, None
def cross_entropy_loss(
logits: torch.Tensor,
labels: torch.Tensor,
label_smoothing: float = 0.0,
logit_scale: float = 1.0,
lse_square_scale: float = 0.0,
ignored_index=-100,
inplace_backward: bool = False,
process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
logits: (batch, vocab_size)
labels: (batch,)
label_smoothing: float
logit_scale: float. Multiply logits by this scale before calculating the loss.
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
This saves memory.
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
one part of the vocab. The loss will be aggregated across processes.
Returns:
losses: (batch,), float
z_losses: (batch,), float
"""
return CrossEntropyLossFunction.apply(
logits,
labels,
label_smoothing,
logit_scale,
lse_square_scale,
ignored_index,
inplace_backward,
process_group,
)
class FusedCrossEntropyLoss(nn.Module):
def __init__(
self,
ignore_index=-100,
reduction="mean",
label_smoothing=0.0,
logit_scale=1.0,
lse_square_scale=0.0,
inplace_backward=False,
process_group=None,
return_z_loss=False,
):
"""
Arguments:
ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
label_smoothing: float
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
This saves memory.
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
one part of the vocab. The loss will be aggregated across processes.
return_z_loss: bool. If True, we return the component of the loss contributed by
the lse_square_scale value. This value is only for logging and does not support
backprop.
"""
super().__init__()
if reduction not in ["mean", "none", "sum"]:
raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
self.logit_scale = logit_scale
self.lse_square_scale = lse_square_scale
self.inplace_backward = inplace_backward
self.process_group = process_group
self.return_z_loss = return_z_loss
def forward(self, input, target):
"""
Arguments:
input: (batch, vocab_size)
target: (batch,)
Returns:
losses: (batch,) if reduction is 'none', else (1,), dtype float
z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
"""
assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
loss, z_loss = cross_entropy_loss(
input,
target,
label_smoothing=self.label_smoothing,
logit_scale=self.logit_scale,
lse_square_scale=self.lse_square_scale,
ignored_index=self.ignore_index,
inplace_backward=self.inplace_backward,
process_group=self.process_group,
)
if self.reduction == "mean":
loss = loss.sum() / (target != self.ignore_index).sum()
elif self.reduction == "sum":
loss = loss.sum()
else:
loss = loss
if not self.return_z_loss:
return loss
if self.reduction == "mean":
z_loss = z_loss.sum() / (target != self.ignore_index).sum()
elif self.reduction == "sum":
z_loss = z_loss.sum()
else:
z_loss = z_loss
return loss, z_loss

View File

@ -0,0 +1,889 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Tri Dao.
# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
# Implement residual + layer_norm / rms_norm.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from fla.utils import contiguous
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
dtype = x.dtype
if upcast:
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if residual is not None:
x = (x + residual).to(x.dtype)
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
dtype
)
return out if not prenorm else (out, x)
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
dtype = x.dtype
if upcast:
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if residual is not None:
x = (x + residual).to(x.dtype)
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + \
bias if bias is not None else (x * rstd * weight)
out = out.to(dtype)
return out if not prenorm else (out, x)
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
O, # pointer to the gate
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
RESIDUAL_OUT, # pointer to the residual
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_res_row,
stride_res_out_row,
N, # number of columns in X
eps, # epsilon to avoid division by zero
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
O += row * stride_x_row
if HAS_RESIDUAL:
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols <
N, other=0.0).to(tl.float32)
x += residual
if STORE_RESIDUAL_OUT:
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
y = x_hat * w if HAS_WEIGHT else x_hat
if HAS_BIAS:
y = y + b
# Swish output gate
o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)
y = y * o * tl.sigmoid(o)
# Write output
tl.store(Y + cols, y, mask=mask)
def _layer_norm_fwd(
x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
):
if residual is not None:
residual_dtype = residual.dtype
M, N = x.shape
assert x.stride(-1) == 1
if residual is not None:
assert residual.stride(-1) == 1
assert residual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
assert y.stride(-1) == 1
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
assert residual_out.stride(-1) == 1
else:
residual_out = None
mean = torch.empty((M,), dtype=torch.float32,
device="cuda") if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
with torch.cuda.device(x.device.index):
_layer_norm_fwd_1pass_kernel[(M,)](
x,
o,
y,
weight,
bias,
residual,
residual_out,
mean,
rstd,
x.stride(0),
y.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
N,
eps,
is_rms_norm,
BLOCK_N,
residual is not None,
residual_out is not None,
weight is not None,
bias is not None,
)
# residual_out is None if residual is None and residual_dtype == input_dtype
return y, mean, rstd, residual_out if residual_out is not None else x
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
X, # pointer to the input
O, # pointer to the gate
W, # pointer to the weights
B, # pointer to the biases
Y, # pointer to the output to be recomputed
DY, # pointer to the output gradient
DX, # pointer to the input gradient
DO, # pointer to the gate gradient
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
DRESIDUAL,
DRESIDUAL_IN,
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_dy_row,
stride_dx_row,
stride_dres_row,
stride_dres_in_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
rows_per_program,
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_DRESIDUAL: tl.constexpr,
STORE_DRESIDUAL: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
RECOMPUTE_OUTPUT: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
row_block_id = tl.program_id(0)
row_start = row_block_id * rows_per_program
cols = tl.arange(0, BLOCK_N)
mask = cols < N
X += row_start * stride_x_row
O += row_start * stride_x_row
if HAS_DRESIDUAL:
DRESIDUAL += row_start * stride_dres_row
if STORE_DRESIDUAL:
DRESIDUAL_IN += row_start * stride_dres_in_row
DY += row_start * stride_dy_row
DX += row_start * stride_dx_row
DO += row_start * stride_dx_row
if RECOMPUTE_OUTPUT:
Y += row_start * stride_y_row
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
if RECOMPUTE_OUTPUT and HAS_BIAS:
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
if HAS_BIAS:
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
row_end = min((row_block_id + 1) * rows_per_program, M)
for row in range(row_start, row_end):
# Load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
o = tl.load(O + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# Compute dx
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
xhat = tl.where(mask, xhat, 0.0)
y = xhat * w if HAS_WEIGHT else xhat
if HAS_BIAS:
y = y + b
if RECOMPUTE_OUTPUT:
tl.store(Y + cols, y, mask=mask)
sigmoid_o = tl.sigmoid(o)
do = dy * y * (sigmoid_o + o * sigmoid_o * (1 - sigmoid_o))
dy = dy * o * sigmoid_o
wdy = dy
if HAS_WEIGHT:
wdy = dy * w
dw += dy * xhat
if HAS_BIAS:
db += dy
if not IS_RMS_NORM:
c1 = tl.sum(xhat * wdy, axis=0) / N
c2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * c1 + c2)) * rstd
else:
c1 = tl.sum(xhat * wdy, axis=0) / N
dx = (wdy - xhat * c1) * rstd
if HAS_DRESIDUAL:
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
dx += dres
# Write dx
if STORE_DRESIDUAL:
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
tl.store(DX + cols, dx, mask=mask)
tl.store(DO + cols, do, mask=mask)
X += stride_x_row
O += stride_x_row
if HAS_DRESIDUAL:
DRESIDUAL += stride_dres_row
if STORE_DRESIDUAL:
DRESIDUAL_IN += stride_dres_in_row
if RECOMPUTE_OUTPUT:
Y += stride_y_row
DY += stride_dy_row
DX += stride_dx_row
DO += stride_dx_row
if HAS_WEIGHT:
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
if HAS_BIAS:
tl.store(DB + row_block_id * N + cols, db, mask=mask)
def _layer_norm_bwd(
dy,
x,
o,
weight,
bias,
eps,
mean,
rstd,
dresidual=None,
has_residual=False,
is_rms_norm=False,
x_dtype=None,
recompute_output=False,
):
M, N = x.shape
assert x.stride(-1) == 1
assert dy.stride(-1) == 1
assert dy.shape == (M, N)
if dresidual is not None:
assert dresidual.stride(-1) == 1
assert dresidual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
dx = (
torch.empty_like(x)
if x_dtype is None
else torch.empty(M, N, dtype=x_dtype, device=x.device)
)
do = (
torch.empty_like(o)
if x_dtype is None
else torch.empty(M, N, dtype=x_dtype, device=x.device)
)
dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
_dw = (
torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
if weight is not None
else None
)
_db = (
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
if bias is not None
else None
)
rows_per_program = math.ceil(M / sm_count)
grid = (sm_count,)
with torch.cuda.device(x.device.index):
_layer_norm_bwd_kernel[grid](
x,
o,
weight,
bias,
y,
dy,
dx,
do,
_dw,
_db,
dresidual,
dresidual_in,
mean,
rstd,
x.stride(0),
0 if not recompute_output else y.stride(0),
dy.stride(0),
dx.stride(0),
dresidual.stride(0) if dresidual is not None else 0,
dresidual_in.stride(0) if dresidual_in is not None else 0,
M,
N,
eps,
rows_per_program,
is_rms_norm,
BLOCK_N,
dresidual is not None,
dresidual_in is not None,
weight is not None,
bias is not None,
)
dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
db = _db.sum(0).to(bias.dtype) if bias is not None else None
# Don't need to compute dresidual_in separately in this case
if has_residual and dx.dtype == x.dtype:
dresidual_in = dx
return (dx, do, dw, db, dresidual_in) if not recompute_output else (dx, do, dw, db, dresidual_in, y)
class LayerNormSwishGateFn(torch.autograd.Function):
@staticmethod
@contiguous
def forward(
ctx,
x,
o,
weight,
bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
x_shape_og = x.shape
o_shape_og = o.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
o = o.reshape(-1, o.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
x, o, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
)
ctx.save_for_backward(residual_out, o, weight, bias, mean, rstd)
ctx.x_shape_og = x_shape_og
ctx.o_shape_og = o_shape_og
ctx.eps = eps
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
y = y.reshape(x_shape_og)
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
@staticmethod
@contiguous
def backward(ctx, dy, *args):
x, o, weight, bias, mean, rstd = ctx.saved_tensors
dy = dy.reshape(-1, dy.shape[-1])
assert dy.shape == x.shape
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
assert dresidual.shape == x.shape
else:
dresidual = None
dx, do, dw, db, dresidual_in = _layer_norm_bwd(
dy,
x,
o,
weight,
bias,
ctx.eps,
mean,
rstd,
dresidual,
ctx.has_residual,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
)
return (
dx.reshape(ctx.x_shape_og),
do.reshape(ctx.o_shape_og),
dw,
db,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
None,
None,
None,
None,
)
class LayerNormSwishGateLinearFn(torch.autograd.Function):
@staticmethod
@contiguous
def forward(
ctx,
x,
o,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
x_shape_og = x.shape
o_shape_og = o.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
o = o.reshape(-1, o.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
x,
o,
norm_weight,
norm_bias,
eps,
residual,
residual_dtype=residual_dtype,
is_rms_norm=is_rms_norm
)
y = y.reshape(x_shape_og)
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
linear_weight = linear_weight.to(dtype)
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
# We don't store y, will be recomputed in the backward pass to save memory
ctx.save_for_backward(residual_out, o, norm_weight, norm_bias, linear_weight, mean, rstd)
ctx.x_shape_og = x_shape_og
ctx.o_shape_og = o_shape_og
ctx.eps = eps
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
ctx.linear_bias_is_none = linear_bias is None
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
@staticmethod
@contiguous
def backward(ctx, dout, *args):
x, o, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
dout = dout.reshape(-1, dout.shape[-1])
dy = F.linear(dout, linear_weight.t())
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
assert dy.shape == x.shape
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
assert dresidual.shape == x.shape
else:
dresidual = None
dx, do, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
dy,
x,
o,
norm_weight,
norm_bias,
ctx.eps,
mean,
rstd,
dresidual=dresidual,
has_residual=ctx.has_residual,
is_rms_norm=ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
recompute_output=True,
)
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
return (
dx.reshape(ctx.x_shape_og),
do.reshape(ctx.o_shape_og),
dnorm_weight,
dnorm_bias,
dlinear_weight,
dlinear_bias,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
None,
None,
None,
None,
)
def layer_norm_swish_gate_fn(
x,
o,
weight,
bias,
residual=None,
prenorm=False,
residual_in_fp32=False,
eps=1e-6
):
return LayerNormSwishGateFn.apply(
x,
o,
weight,
bias,
residual,
eps,
prenorm,
residual_in_fp32,
False
)
def rms_norm_swish_gate_fn(
x,
o,
weight,
bias,
residual=None,
prenorm=False,
residual_in_fp32=False,
eps=1e-6
):
return LayerNormSwishGateFn.apply(
x,
o,
weight,
bias,
residual,
eps,
prenorm,
residual_in_fp32,
True
)
def layer_norm_swish_gate_linear_fn(
x,
o,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
prenorm=False,
residual_in_fp32=False,
eps=1e-6
):
return LayerNormSwishGateLinearFn.apply(
x,
o,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual,
eps,
prenorm,
residual_in_fp32,
False
)
def rms_norm_swish_gate_linear_fn(
x,
o,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
prenorm=False,
residual_in_fp32=False,
eps=1e-6
):
return LayerNormSwishGateLinearFn.apply(
x,
o,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual,
eps,
prenorm,
residual_in_fp32,
True
)
class FusedLayerNormSwishGate(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> FusedLayerNormSwishGate:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
return layer_norm_swish_gate_fn(
x,
o,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32
)
class FusedRMSNormSwishGate(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> FusedRMSNormSwishGate:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
return rms_norm_swish_gate_fn(
x,
o,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32
)
class FusedLayerNormSwishGateLinear(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> FusedLayerNormSwishGateLinear:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
return layer_norm_swish_gate_linear_fn(
x,
o,
self.weight,
self.bias,
weight,
bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32
)
class FusedRMSNormSwishGateLinear(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> FusedRMSNormSwishGateLinear:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
return rms_norm_swish_gate_linear_fn(
x,
o,
self.weight,
self.bias,
weight,
bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32
)

216
finetune/lora/v6/fla/modules/l2norm.py vendored Normal file
View File

@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _l2_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
stride_x_row, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_N: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_x_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0)
rstd = 1 / tl.sqrt(var + eps)
# tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
y = x * rstd
# Write output
tl.store(Y + cols, y, mask=mask)
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _l2_norm_bwd_kernel(
X, # pointer to the input
# Y, # pointer to the output to be recomputed
DY, # pointer to the output gradient
DX, # pointer to the input gradient
stride_x_row, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_N: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
DX += row * stride_x_row
DY += row * stride_x_row
# Y += row * stride_y_row
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
x = tl.where(cols < N, x, 0.0)
var = tl.sum(x * x)
rstd = 1 / tl.sqrt(var + eps)
# tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
# y = x * rstd
dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32)
dy = tl.where(cols < N, dy, 0.0)
# dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x
dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x
tl.store(DX + cols, dx, mask=mask)
def _l2_norm_fwd(
x, eps=1e-6
):
x_shape_og = x.shape
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
M, N = x.shape
assert x.stride(-1) == 1
# allocate output
y = torch.empty_like(x)
assert y.stride(-1) == 1
N = x.shape[-1]
M = x.shape[0]
# rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
with torch.cuda.device(x.device.index):
_l2_norm_fwd_1pass_kernel[(M,)](
x,
y,
x.stride(0),
N,
eps,
# is_rms_norm,
BLOCK_N,
# residual is not None,
# residual_out is not None,
# bias is not None,
)
return y.reshape(x_shape_og)
def _l2_norm_bwd(
x, dy, eps=1e-5,
):
x_shape_og = x.shape
x = x.reshape(-1, dy.shape[-1])
dy = dy.reshape(-1, dy.shape[-1])
if dy.stride(-1) != 1:
dy = dy.contiguous()
assert dy.shape == x.shape
# allocate output
dx = torch.empty_like(x)
N = x.shape[-1]
M = x.shape[0]
assert x.stride(-1) == 1
assert dy.stride(-1) == 1
# rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
with torch.cuda.device(x.device.index):
_l2_norm_bwd_kernel[(M,)](
x,
dy,
dx,
x.stride(0),
N,
eps,
BLOCK_N,
)
return dx.reshape(x_shape_og)
class L2NormFN(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x,
eps=1e-6,
):
# reshape input data into 2D tensor
y = _l2_norm_fwd(x, eps)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.x_dtype = x.dtype
ctx.save_for_backward(x)
return y
@staticmethod
def backward(ctx, dy, *args):
x, = ctx.saved_tensors
dx = _l2_norm_bwd(
x,
dy,
ctx.eps,
)
return (
dx,
None
)
l2_norm_fn = L2NormFN.apply
if __name__ == '__main__':
x = torch.rand(10, 10, 100).cuda().requires_grad_(True)
y = torch.nn.functional.normalize(x, dim=-1, p=2)
dy = torch.rand_like(y)
y.backward(dy, retain_graph=True)
x_grad, x.grad = x.grad, None
y2 = l2_norm_fn(x, 1e-6)
print((y-y2).abs().max())
y2.backward(dy, retain_graph=True)
x_grad2, x.grad = x.grad, None
print((x_grad2-x_grad).abs().max())
breakpoint()

View File

@ -0,0 +1,802 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Tri Dao.
# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
# Implement residual + layer_norm / rms_norm.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from fla.utils import contiguous
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
dtype = x.dtype
if upcast:
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if residual is not None:
x = (x + residual).to(x.dtype)
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
dtype
)
return out if not prenorm else (out, x)
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
dtype = x.dtype
if upcast:
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if residual is not None:
x = (x + residual).to(x.dtype)
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + \
bias if bias is not None else (x * rstd * weight)
out = out.to(dtype)
return out if not prenorm else (out, x)
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
RESIDUAL_OUT, # pointer to the residual
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_res_row,
stride_res_out_row,
N, # number of columns in X
eps, # epsilon to avoid division by zero
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_RESIDUAL:
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols <
N, other=0.0).to(tl.float32)
x += residual
if STORE_RESIDUAL_OUT:
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
y = x_hat * w if HAS_WEIGHT else x_hat
if HAS_BIAS:
y = y + b
# Write output
tl.store(Y + cols, y, mask=mask)
def _layer_norm_fwd(
x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
):
if residual is not None:
residual_dtype = residual.dtype
M, N = x.shape
assert x.stride(-1) == 1
if residual is not None:
assert residual.stride(-1) == 1
assert residual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
assert y.stride(-1) == 1
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
assert residual_out.stride(-1) == 1
else:
residual_out = None
mean = torch.empty((M,), dtype=torch.float32,
device="cuda") if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
with torch.cuda.device(x.device.index):
_layer_norm_fwd_1pass_kernel[(M,)](
x,
y,
weight,
bias,
residual,
residual_out,
mean,
rstd,
x.stride(0),
y.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
N,
eps,
is_rms_norm,
BLOCK_N,
residual is not None,
residual_out is not None,
weight is not None,
bias is not None,
)
# residual_out is None if residual is None and residual_dtype == input_dtype
return y, mean, rstd, residual_out if residual_out is not None else x
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
X, # pointer to the input
W, # pointer to the weights
B, # pointer to the biases
Y, # pointer to the output to be recomputed
DY, # pointer to the output gradient
DX, # pointer to the input gradient
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
DRESIDUAL,
DRESIDUAL_IN,
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_dy_row,
stride_dx_row,
stride_dres_row,
stride_dres_in_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
rows_per_program,
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_DRESIDUAL: tl.constexpr,
STORE_DRESIDUAL: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
RECOMPUTE_OUTPUT: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
row_block_id = tl.program_id(0)
row_start = row_block_id * rows_per_program
cols = tl.arange(0, BLOCK_N)
mask = cols < N
X += row_start * stride_x_row
if HAS_DRESIDUAL:
DRESIDUAL += row_start * stride_dres_row
if STORE_DRESIDUAL:
DRESIDUAL_IN += row_start * stride_dres_in_row
DY += row_start * stride_dy_row
DX += row_start * stride_dx_row
if RECOMPUTE_OUTPUT:
Y += row_start * stride_y_row
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
if RECOMPUTE_OUTPUT and HAS_BIAS:
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
if HAS_BIAS:
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
row_end = min((row_block_id + 1) * rows_per_program, M)
for row in range(row_start, row_end):
# Load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# Compute dx
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
xhat = tl.where(mask, xhat, 0.0)
if RECOMPUTE_OUTPUT:
y = xhat * w if HAS_WEIGHT else xhat
if HAS_BIAS:
y = y + b
tl.store(Y + cols, y, mask=mask)
wdy = dy
if HAS_WEIGHT:
wdy = dy * w
dw += dy * xhat
if HAS_BIAS:
db += dy
if not IS_RMS_NORM:
c1 = tl.sum(xhat * wdy, axis=0) / N
c2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * c1 + c2)) * rstd
else:
c1 = tl.sum(xhat * wdy, axis=0) / N
dx = (wdy - xhat * c1) * rstd
if HAS_DRESIDUAL:
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
dx += dres
# Write dx
if STORE_DRESIDUAL:
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
tl.store(DX + cols, dx, mask=mask)
X += stride_x_row
if HAS_DRESIDUAL:
DRESIDUAL += stride_dres_row
if STORE_DRESIDUAL:
DRESIDUAL_IN += stride_dres_in_row
if RECOMPUTE_OUTPUT:
Y += stride_y_row
DY += stride_dy_row
DX += stride_dx_row
if HAS_WEIGHT:
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
if HAS_BIAS:
tl.store(DB + row_block_id * N + cols, db, mask=mask)
def _layer_norm_bwd(
dy,
x,
weight,
bias,
eps,
mean,
rstd,
dresidual=None,
has_residual=False,
is_rms_norm=False,
x_dtype=None,
recompute_output=False,
):
M, N = x.shape
assert x.stride(-1) == 1
assert dy.stride(-1) == 1
assert dy.shape == (M, N)
if dresidual is not None:
assert dresidual.stride(-1) == 1
assert dresidual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
dx = (
torch.empty_like(x)
if x_dtype is None
else torch.empty(M, N, dtype=x_dtype, device=x.device)
)
dresidual_in = torch.empty_like(
x) if has_residual and dx.dtype != x.dtype else None
y = torch.empty(M, N, dtype=dy.dtype,
device=dy.device) if recompute_output else None
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
_dw = (
torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
if weight is not None
else None
)
_db = (
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
if bias is not None
else None
)
rows_per_program = math.ceil(M / sm_count)
grid = (sm_count,)
with torch.cuda.device(x.device.index):
_layer_norm_bwd_kernel[grid](
x,
weight,
bias,
y,
dy,
dx,
_dw,
_db,
dresidual,
dresidual_in,
mean,
rstd,
x.stride(0),
0 if not recompute_output else y.stride(0),
dy.stride(0),
dx.stride(0),
dresidual.stride(0) if dresidual is not None else 0,
dresidual_in.stride(0) if dresidual_in is not None else 0,
M,
N,
eps,
rows_per_program,
is_rms_norm,
BLOCK_N,
dresidual is not None,
dresidual_in is not None,
weight is not None,
bias is not None,
)
dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
db = _db.sum(0).to(bias.dtype) if bias is not None else None
# Don't need to compute dresidual_in separately in this case
if has_residual and dx.dtype == x.dtype:
dresidual_in = dx
return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
class LayerNormFn(torch.autograd.Function):
@staticmethod
@contiguous
def forward(
ctx,
x,
weight,
bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
)
ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
y = y.reshape(x_shape_og)
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
@staticmethod
@contiguous
def backward(ctx, dy, *args):
x, weight, bias, mean, rstd = ctx.saved_tensors
dy = dy.reshape(-1, dy.shape[-1])
assert dy.shape == x.shape
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
assert dresidual.shape == x.shape
else:
dresidual = None
dx, dw, db, dresidual_in = _layer_norm_bwd(
dy,
x,
weight,
bias,
ctx.eps,
mean,
rstd,
dresidual,
ctx.has_residual,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
)
return (
dx.reshape(ctx.x_shape_og),
dw,
db,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
None,
None,
None,
None,
)
def layer_norm_fn(
x,
weight,
bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
def rms_norm_fn(
x,
weight,
bias,
residual=None,
prenorm=False,
residual_in_fp32=False,
eps=1e-6
):
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
class LayerNorm(nn.Module):
def __init__(
self,
hidden_size: int,
elementwise_affine: bool = True,
eps: float = 1e-5
) -> LayerNorm:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
return layer_norm_fn(
x,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32
)
class RMSNorm(nn.Module):
def __init__(
self,
hidden_size: int,
elementwise_affine: bool = True,
eps: float = 1e-5
) -> RMSNorm:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
return rms_norm_fn(
x,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
)
class LayerNormLinearFn(torch.autograd.Function):
@staticmethod
@contiguous
def forward(
ctx,
x,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
x,
norm_weight,
norm_bias,
eps,
residual,
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
residual_dtype=residual_dtype,
is_rms_norm=is_rms_norm,
)
y = y.reshape(x_shape_og)
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
linear_weight = linear_weight.to(dtype)
linear_bias = linear_bias.to(
dtype) if linear_bias is not None else None
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
# We don't store y, will be recomputed in the backward pass to save memory
ctx.save_for_backward(residual_out, norm_weight,
norm_bias, linear_weight, mean, rstd)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
ctx.linear_bias_is_none = linear_bias is None
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
@staticmethod
@contiguous
def backward(ctx, dout, *args):
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
dout = dout.reshape(-1, dout.shape[-1])
dy = F.linear(dout, linear_weight.t())
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
assert dy.shape == x.shape
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
assert dresidual.shape == x.shape
else:
dresidual = None
dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
dy,
x,
norm_weight,
norm_bias,
ctx.eps,
mean,
rstd,
dresidual,
ctx.has_residual,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
recompute_output=True,
)
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
return (
dx.reshape(ctx.x_shape_og),
dnorm_weight,
dnorm_bias,
dlinear_weight,
dlinear_bias,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
None,
None,
None,
None,
)
def layer_norm_linear_fn(
x,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
return LayerNormLinearFn.apply(
x,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual,
eps,
prenorm,
residual_in_fp32,
is_rms_norm,
)
class LayerNormLinear(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> LayerNormLinear:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
return layer_norm_linear_fn(
x,
self.weight,
self.bias,
weight,
bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=False
)
class RMSNormLinear(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> RMSNormLinear:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
return layer_norm_linear_fn(
x,
self.weight,
self.bias,
weight,
bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=True
)

310
finetune/lora/v6/fla/modules/rotary.py vendored Normal file
View File

@ -0,0 +1,310 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Tri Dao.
from typing import Optional, Tuple, Union
import torch
from einops import rearrange, repeat
from fla.ops.rotary import apply_rotary
def rotate_half(x, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
return torch.cat(
[x[..., :ro_dim] * cos +
rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
dim=-1,
)
class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x,
cos,
sin,
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
):
out = apply_rotary(
x,
cos,
sin,
seqlen_offsets=seqlen_offsets,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
interleaved=interleaved,
inplace=inplace,
)
if isinstance(seqlen_offsets, int):
# Can't save int with save_for_backward
ctx.save_for_backward(cos, sin, cu_seqlens)
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
ctx.inplace = inplace
ctx.max_seqlen = max_seqlen
return out if not inplace else x
@staticmethod
def backward(ctx, do):
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
else:
cos, sin, cu_seqlens = ctx.saved_tensors
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
if not ctx.interleaved and not ctx.inplace:
do = do.clone()
dx = apply_rotary(
do,
cos,
sin,
seqlen_offsets=seqlen_offsets,
cu_seqlens=cu_seqlens,
max_seqlen=ctx.max_seqlen,
interleaved=ctx.interleaved,
inplace=ctx.inplace,
conjugate=True,
)
return dx, None, None, None, None, None, None, None
def apply_rotary_emb(
x,
cos,
sin,
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
):
"""
Arguments:
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
cos, sin: (seqlen_rotary, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
inplace: if True, apply rotary embedding in-place.
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Return:
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
return ApplyRotaryEmb.apply(
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
)
# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb
class RotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et. al).
A crucial insight from the method is that the query and keys are
transformed by rotation matrices which depend on the relative positions.
Other implementations are available in the Rotary Transformer repo_ and in
GPT-NeoX_, GPT-NeoX was an inspiration
.. _RoFormer: https://arxiv.org/abs/2104.09864
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
"""
def __init__(
self,
dim: int,
base=10000.0,
interleaved=False,
scale_base=None,
pos_idx_in_fp32=True,
device=None,
):
"""
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
otherwise they might be in lower precision.
This option was added because previously (before 2023-07-02), when we construct
the position indices, we use the dtype of self.inv_freq. In most cases this would
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
self.inv_freq would be bf16, and the position indices are also in bf16.
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
embeddings for some positions will coincide.
To maintain compatibility with models previously trained in pure bf16,
we add this option.
"""
super().__init__()
self.dim = dim
self.base = float(base)
self.pos_idx_in_fp32 = pos_idx_in_fp32
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = self._compute_inv_freq(device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.interleaved = interleaved
self.scale_base = scale_base
scale = (
(torch.arange(0, dim, 2, device=device,
dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
if scale_base is not None
else None
)
self.register_buffer("scale", scale, persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
def _compute_inv_freq(self, device=None):
return 1.0 / (
self.base
** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
)
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
# if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())
):
self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
if self.pos_idx_in_fp32:
t = torch.arange(seqlen, device=device, dtype=torch.float32)
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
# will be large. Having it in bf16 will lose a lot of precision and cause the
# cos & sin output to change significantly.
# We want to recompute self.inv_freq if it was not loaded in fp32
if self.inv_freq.dtype != torch.float32:
inv_freq = self._compute_inv_freq(device=device)
else:
inv_freq = self.inv_freq
else:
t = torch.arange(seqlen, device=device,
dtype=self.inv_freq.dtype)
inv_freq = self.inv_freq
# Don't do einsum, it converts fp32 to fp16 under AMP
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, inv_freq)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = (
torch.arange(seqlen, dtype=self.scale.dtype,
device=self.scale.device)
- seqlen // 2
) / self.scale_base
scale = self.scale.to(
device=power.device) ** rearrange(power, "s -> s 1")
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
seqlen_offset: Union[int, torch.Tensor] = 0,
max_seqlen: Optional[int] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
else it's just q of shape (batch, seqlen, nheads, headdim)
kv: (batch, seqlen, 2, nheads, headdim)
seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
should pass in max_seqlen, which will update the cos / sin cache up to that length.
Apply rotary embedding *inplace* to qkv and / or kv.
"""
seqlen = q.shape[1]
if max_seqlen is not None:
self._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype)
elif isinstance(seqlen_offset, int):
self._update_cos_sin_cache(seqlen + seqlen_offset, device=q.device, dtype=q.dtype)
if self.scale is None:
q = apply_rotary_emb_func(
q,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
k = apply_rotary_emb_func(
k,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
else:
q = apply_rotary_emb_func(
q,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
k = apply_rotary_emb_func(
k,
self._cos_k_cached,
self._sin_k_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
return q, k

18
finetune/lora/v6/fla/ops/__init__.py vendored Normal file
View File

@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
from .based import fused_chunk_based, parallel_based
from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
from .retention import (chunk_retention, fused_chunk_retention,
fused_recurrent_retention, parallel_retention)
__all__ = [
'fused_chunk_based',
'parallel_based',
'chunk_gla',
'fused_chunk_gla',
'fused_recurrent_gla',
'chunk_retention',
'fused_chunk_retention',
'fused_recurrent_retention',
'parallel_retention'
]

View File

@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_abc
from .chunk_gate import chunk_gated_abc
from .recurrent_fuse import fused_recurrent_gated_abc
__all__ = [
'chunk_abc',
'chunk_gated_abc',
'fused_recurrent_gated_abc'
]

1194
finetune/lora/v6/fla/ops/abc/chunk.py vendored Normal file

File diff suppressed because it is too large Load Diff

1287
finetune/lora/v6/fla/ops/abc/chunk_gate.py vendored Normal file

File diff suppressed because it is too large Load Diff

90
finetune/lora/v6/fla/ops/abc/naive.py vendored Normal file
View File

@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
from typing import Optional
import torch
def naive_recurrent_abc(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
s: torch.Tensor,
g: Optional[torch.Tensor] = None,
scale: Optional[int] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: Optional[bool] = False
) -> torch.Tensor:
dtype = q.dtype
# [batch_size, n_heads, seq_len, n_slots]
if g is None:
z = s.float().logcumsumexp(2)
g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
s = torch.exp(s - z)
q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
ok = torch.zeros_like(s)
if scale is None:
scale = q.shape[-1] ** -0.5
final_state = None
if initial_state is not None:
hk += initial_state[0]
for i in range(T):
q_i = q[:, :, i] * scale
k_i = k[:, :, i]
v_i = s[:, :, i]
g_i = g[:, :, i].exp()
hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
ok[:, :, i] = (q_i[..., None] * hk).sum(-2)
qv = ok.softmax(-1)
hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
ov = torch.zeros_like(v)
if initial_state is not None:
hv += initial_state[1]
for i in range(T):
q_i = qv[:, :, i]
k_i = s[:, :, i]
v_i = v[:, :, i]
g_i = g[:, :, i].exp()
hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
ov[:, :, i] = (q_i[..., None] * hv).sum(-2)
if output_final_state:
final_state = (hk, hv)
return ov.to(dtype), final_state
def naive_cumsum_abc(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
s: torch.Tensor
) -> torch.Tensor:
"""
A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper.
This is just for demonstration purposes, with no numerical stabilities guaranteed.
"""
dtype = q.dtype
q, k, v, s = map(lambda x: x.float(), (q, k, v, s))
scale = q.shape[-1] ** -0.5
# [batch_size, n_heads, seq_len, n_slots]
s = (s - s.max(2, True)[0]).exp()
z = s.cumsum(2)
# [batch_size, n_heads, seq_len, n_slots, d_head]
K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
# [batch_size, n_heads, seq_len, n_slots]
p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1)
# [batch_size, n_heads, seq_len, d_head]
o = torch.einsum('...m,...md->...d', p, V)
return o.to(dtype), None

View File

@ -0,0 +1,388 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Yu Zhang, Songlin Yang
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
@triton.jit
def fused_recurrent_gated_abc_fwd_kernel(
q,
k,
v,
gk,
gv,
o,
h0,
ht,
s_k_h,
s_v_h,
scale,
B: tl.constexpr,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
REVERSE: tl.constexpr,
USE_GK: tl.constexpr,
USE_GV: tl.constexpr,
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
mask_bk = (i_k * BK + tl.arange(0, BK)) < K
mask_bv = (i_v * BV + tl.arange(0, BV)) < V
h = tl.zeros([BV, BK], dtype=tl.float32)
mask_kv = mask_bk[None, :] & mask_bv[:, None]
if USE_INITIAL_STATE:
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
if USE_GK:
b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
h = h * b_gk[None, :]
if USE_GV:
b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
h = h * b_gv[:, None]
h += b_k[None, :] * b_v[:, None]
b_o = h * b_q[None, :]
b_o = tl.sum(b_o, axis=1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
p_q += -K if REVERSE else K
p_k += -K if REVERSE else K
p_o += -V if REVERSE else V
p_v += -V if REVERSE else V
if USE_GK:
p_gk += -K if REVERSE else K
if USE_GV:
p_gv += -V if REVERSE else V
if STORE_FINAL_STATE:
p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)
@triton.jit
def fused_recurrent_gated_abc_bwd_kernel(
q,
k,
v,
gk,
gv,
do,
dq,
dk,
dv,
h0,
s_k_h,
s_v_h,
scale,
B: tl.constexpr,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
REVERSE: tl.constexpr,
USE_GK: tl.constexpr,
USE_GV: tl.constexpr,
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
mask_bk = i_k * BK + tl.arange(0, BK) < K
mask_bv = i_v * BV + tl.arange(0, BV) < V
mask_kv = mask_bk[:, None] & mask_bv[None, :]
h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
if USE_GK:
b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
h = h * b_gk[:, None]
if USE_GV:
b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
h = h * b_gv[None, :]
h += b_k[:, None] * b_v[None, :]
b_dq = tl.sum(h * b_do[None, :], axis=1) * scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)
p_k += -K if REVERSE else K
p_v += -V if REVERSE else V
p_q += -K if REVERSE else K
p_do += -V if REVERSE else V
p_dq += -K if REVERSE else K
if USE_GK:
p_gk += -K if REVERSE else K
if USE_GV:
p_gv += -V if REVERSE else V
# sync threads
tl.debug_barrier()
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for _ in range(T):
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
b_dh += b_q[:, None] * b_do[None, :]
b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
if USE_GK:
b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
b_dh *= b_gk[:, None]
if USE_GV:
b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
b_dh *= b_gv[None, :]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
p_q += K if REVERSE else -K
p_k += K if REVERSE else -K
p_v += V if REVERSE else -V
p_do += V if REVERSE else -V
p_dk += K if REVERSE else -K
p_dv += V if REVERSE else -V
if USE_GK:
p_gk += K if REVERSE else -K
if USE_GV:
p_gv += V if REVERSE else -V
class FusedRecurrentGatedABCFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, s, g, scale=None, initial_state=None, output_final_state=False, reverse=False):
B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
# default scale
if scale is None:
scale = K ** -0.5
BK, BV, BM = min(K, 32), min(V, 32), min(M, 32)
NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
num_stages = 1
num_warps = 1
g = g.float().exp()
final_state = (None, None)
if output_final_state:
final_state = (q.new_empty(B, H, K, M), q.new_empty(B, H, M, V))
ok = q.new_empty(NK, B, H, T, M, dtype=torch.float)
gk, gv = None, g
grid = (NM, NK, B * H)
fused_recurrent_gated_abc_fwd_kernel[grid](
q, k, s, gk, gv, ok, initial_state[0], final_state[0],
k.stride(1),
s.stride(1),
scale=scale,
B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,
USE_INITIAL_STATE=initial_state[0] is not None,
STORE_FINAL_STATE=final_state[0] is not None,
USE_GK=False,
USE_GV=True,
REVERSE=reverse,
num_warps=num_warps,
num_stages=num_stages
)
ok = ok.sum(0)
qv = ok.softmax(-1, dtype=torch.float)
ov = q.new_empty(NM, B, H, T, V, dtype=torch.float)
gk, gv = g, None
grid = (NV, NM, B * H)
fused_recurrent_gated_abc_fwd_kernel[grid](
qv, s, v, gk, gv, ov, initial_state[1], final_state[1],
s.stride(1),
v.stride(1),
scale=1.,
B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,
USE_INITIAL_STATE=initial_state[0] is not None,
STORE_FINAL_STATE=final_state[0] is not None,
USE_GK=True,
USE_GV=False,
REVERSE=reverse,
num_warps=num_warps,
num_stages=num_stages
)
ov = ov.sum(0)
ctx.save_for_backward(q, k, v, s, g, qv, *initial_state, ok)
ctx.scale = scale
ctx.reverse = reverse
# we do not need the gradient of the final state from the next chunk
# similiar to Trunctated BPTT
if final_state is not None:
final_state = tuple(i.detach() for i in final_state)
return ov.to(q.dtype), final_state
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, dht=None):
q, k, v, s, g, qv, *initial_state, ok = ctx.saved_tensors
B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
V = v.shape[-1]
scale = ctx.scale
BK, BV, BM = min(K, 32), min(V, 32), min(M, 32)
NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
num_stages = 1
num_warps = 1
dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
dv = q.new_empty(NM, B, H, T, V, dtype=torch.float)
gk, gv = g, None
grid = (NV, NM, B * H)
fused_recurrent_gated_abc_bwd_kernel[grid](
qv, s, v, gk, gv, do, dqv, dsv, dv, initial_state[1],
s.stride(1),
v.stride(1),
scale=1.,
B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state[1] is not None,
REVERSE=ctx.reverse,
USE_GK=gk is not None,
USE_GV=gv is not None
)
dqv = dqv.sum(0)
dsv = dsv.sum(0)
dv = dv.sum(0)
dgk = dqv * qv.float() - dsv * s.float()
dgk_cumsum = dgk.cumsum(-2)
dgk = dgk + dgk_cumsum[:, :, -1, None] - dgk_cumsum
dok = qv * (dqv - (qv * dqv).sum(-1, True))
dq = q.new_empty(NM, B, H, T, K, dtype=torch.float)
dk = q.new_empty(NM, B, H, T, K, dtype=torch.float)
dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float)
gk, gv = None, g
grid = (NM, NK, B * H)
fused_recurrent_gated_abc_bwd_kernel[grid](
q, k, s, gk, gv, dok, dq, dk, dsk, initial_state[0],
q.stride(1),
s.stride(1),
scale=scale,
B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state[0] is not None,
REVERSE=ctx.reverse,
USE_GK=gk is not None,
USE_GV=gv is not None
)
dq = dq.sum(0)
dk = dk.sum(0)
dsk = dsk.sum(0)
dgv = dok.float() * ok.float() - dsk * s.float()
dgv_cumsum = dgv.cumsum(-2)
dgv = dgv + dgv_cumsum[:, :, -1, None] - dgv_cumsum
ds = dsk.add_(dsv)
dg = dgk.add_(dgv)
return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, None, None, None
def fused_recurrent_gated_abc(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
s: torch.Tensor,
g: Optional[torch.Tensor] = None,
scale: Optional[int] = None,
initial_state: Optional[Tuple[torch.Tensor]] = None,
output_final_state: Optional[bool] = False
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
q (torch.Tensor):
queries of shape `(B, H, T, K)`
k (torch.Tensor):
keys of shape `(B, H, T, K)`
v (torch.Tensor):
values of shape `(B, H, T, V)`
g (torch.Tensor):
Forget gates of shape `(B, H, T, M)` applied to keys.
If not provided, this function is equivalent to vanilla ABC.
scale (Optional[int]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[Tuple[torch.Tensor]]):
Initial state tuple having tensors of shape `(B, H, K, V)`. Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state tuple, having tensors of shape `(B, H, K, V)`. Default: `False`.
"""
if initial_state is not None:
initial_state = tuple(i.detach() for i in initial_state)
if g is None:
# TODO: this 3 steps took huge amount of time, ought to be optimized
z = s.float().logcumsumexp(2)
g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
s = torch.exp(s - z).to(k.dtype)
if scale is None:
scale = q.shape[-1] ** -0.5
ov, final_state = FusedRecurrentGatedABCFunction.apply(q, k, v, s, g, scale, initial_state, output_final_state)
return ov, final_state

View File

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from .chunk_fuse import fused_chunk_based
from .parallel import parallel_based
__all__ = [
'fused_chunk_based',
'parallel_based'
]

View File

@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@triton.jit
def fused_chunk_based_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
z, # normalizer [B, H, L, 1]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
# [BT, BT]
m_s = o_i[:, None] >= o_i[None, :]
# [BV], zero-order taylor expansion
b_h_0o = tl.zeros([BV], dtype=tl.float32)
# [BK, BV], first-order taylor expansion
b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)
# [BK, BK, BV] second-order taylor expansion
b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
# make block pointers
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
(s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)
k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
k_1o = tl.zeros([1, BK], dtype=tl.float32)
k_0o = 0
for i in range(0, tl.cdiv(T, BT)):
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BK*BK, BT]
b_k_2o = b_k[:, None, :] * b_k[None, :, :]
b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BK]
b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_z = tl.zeros([BT], dtype=tl.float32)
# interchunk
# zero-order
b_o += b_h_0o
b_z += k_0o
# first-order
b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)
b_z += tl.sum(b_q * k_1o, axis=1)
# second-order
b_q_2o = b_q[:, :, None] * b_q[:, None, :]
b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)
b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5
b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5
# update running statistics
k_1o += tl.sum(b_k, axis=1)[None, :]
k_2o += tl.sum(b_k_2o, axis=1)[None, :]
k_0o += BT
# intrachunk
# [BT, BT]
b_s = tl.dot(b_q, b_k, allow_tf32=False)
b_s = 1 + b_s + 0.5 * b_s * b_s
b_s = tl.where(m_s, b_s, 0)
b_z += tl.sum(b_s, axis=1)
b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
# [TB, BV]
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_z, b_z.to(p_z.dtype.element_ty),
mask=(i * BT + tl.arange(0, BT)) < T)
# update hidden state
# [BK, BV]
b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)
b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)
b_h_0o = b_h_0o + tl.sum(b_v, axis=0)
p_q = tl.advance(p_q, (BT, 0))
p_k = tl.advance(p_k, (0, BT))
p_v = tl.advance(p_v, (BT, 0))
p_o = tl.advance(p_o, (BT, 0))
p_z += BT
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_chunk_based_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
do, # gradient of output [B, H, L, D_head_V]
dz, # gradient of normalizer [B, H, L]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
# [BV], zero-order taylor expansion
# b_h_0o = tl.zeros([BV], dtype=tl.float32)
# [BK, BV], first-order taylor expansion
b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)
# [BK, BK, BV] second-order taylor expansion
b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)
k_1o = tl.zeros([1, BK], dtype=tl.float32)
k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
for i in range(0, tl.cdiv(T, BT)):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(
k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(
v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h,
(T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
# load tensors
# [BT, BK]
b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BV, BT]
b_v = tl.load(p_v, boundary_check=(0, 1))
# inter-chunk
b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)
if i_v == 0:
b_dq += b_dz[:, None] * k_1o
b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5
if i_v == 0:
b_dq_2o += (b_dz[:, None] * k_2o) * 0.5
b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])
b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)
b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)
b_dq *= scale
# intra-chunk
# [BT, BT]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
if i_v == 0:
b_ds += b_dz[:, None]
b_ds = tl.where(m_s, b_ds, 0) * scale
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)
# store
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
# update hidden state
# [BT, BK*BK]
b_k_2o = b_k[:, :, None] * b_k[:, None, :]
b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
# [BV, BK*BK]
b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)
# [BV, BK]
b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)
if i_v == 0:
# update running statistics
k_1o += tl.sum(b_k, axis=0)[None, :]
k_2o += tl.sum(b_k_2o, axis=0)[None, :]
tl.debug_barrier()
b_h_1o = None
b_h_2o = None
# [BK, BV], first-order taylor expansion
b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)
# [BK, BK, BV] second-order taylor expansion
b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
b_dh_0o = tl.zeros([BV], dtype=tl.float32)
m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]
dq_1o = tl.zeros([1, BK], dtype=tl.float32)
dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)
for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(
k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(
v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0))
p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dv = tl.zeros([BT, BV], dtype=tl.float32)
b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
b_q = (b_q * scale).to(b_k.dtype)
# intra chunk
b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
if i_v == 0:
b_ds += b_dz[None, :]
b_ds = tl.where(m_s, b_ds, 0)
b_s = tl.dot(b_k, b_q, allow_tf32=False)
b_s2 = 1 + b_s + 0.5 * b_s * b_s
b_s = tl.where(m_s, b_s, 0)
b_s2 = tl.where(m_s, b_s2, 0)
b_ds *= (1+b_s)
b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)
b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)
# inter chunk
b_k_2o = b_k[:, :, None] * b_k[:, None, :]
b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)
b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)
b_dv += b_dh_0o
b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)
if i_v == 0:
b_dk += dq_1o
b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype),
tl.trans(b_v), allow_tf32=False)
if i_v == 0:
b_dk_2o += dq_2o
b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])
b_k_fp32 = tl.trans(b_k.to(tl.float32))
b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)
b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)
b_dk += tl.trans(b_dk2)
# hidden state update
b_dh_0o += tl.sum(b_do, axis=0)
b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)
b_q_2o = b_q[None, :, :] * b_q[:, None, :]
b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)
b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5
if i_v == 0:
dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]
dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
class FusedChunkBasedFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, scale=1):
batch_size, n_heads, seq_len, d_head_qk = q.shape
# assert d_head_qk == 16, "currently we do not support feature dim other than 16"
d_head_v = v.shape[-1]
scale = scale
BT = 16
BK, BV = min(d_head_qk, 16), min(d_head_v, 32)
BK, BV = max(BK, 16), max(BV, 16)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_warps = 4
# the norm of o might explode, so we need to use float32 here
o = q.new_empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=torch.float32)
z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_based_fwd_kernel[grid](
q, k, v, o, z,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
)
o = o.sum(0)
z = z.sum(0)
ctx.save_for_backward(q, k, v)
ctx.scale = scale
return o.to(q.dtype), z.to(z.dtype)
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, dz):
q, k, v = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = ctx.scale
BT = 16
BK, BV = min(d_head_qk, 16), min(d_head_v, 32)
BK, BV = max(BK, 16), max(BV, 16)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 4
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_based_bwd_kernel[grid](
q, k, v, do, dz, dq, dk, dv,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None
triton_fused_chunk_based = FusedChunkBasedFunction.apply
def fused_chunk_based(q, k, v, use_scale=True, use_normalize=True):
assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'
if use_scale:
scale = q.shape[-1] ** -0.5
else:
scale = 1
o, z = triton_fused_chunk_based(q, k, v, scale)
if use_normalize:
o = o / (z[..., None] + 1e-6)
else:
o = o
return o.to(q.dtype)

132
finetune/lora/v6/fla/ops/based/naive.py vendored Normal file
View File

@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
import torch
from einops import rearrange
from fla.ops.based.chunk_fuse import fused_chunk_based
from fla.ops.based.parallel import parallel_based
def naive_parallel_based(q, k, v, use_scale=True, use_norm=True):
if use_scale:
q = q * (q.shape[-1] ** -0.5)
attn = q @ k.transpose(-2, -1)
attn = 1 + attn + 1/2 * (attn ** 2)
attn.masked_fill_(~torch.tril(torch.ones(
q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
o = attn @ v
if use_norm:
z = attn.sum(-1)
return o / (z[..., None] + 1e-6)
else:
return o
def naive_chunk_based(q, k, v, chunk_size=256):
q = q * (q.shape[-1] ** -0.5)
# compute normalizer.
k_cumsum = torch.cumsum(k, dim=-2)
kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
# first
z = (q * k_cumsum).sum(-1)
# second order
z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
# zero-th order
z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]
# compute o
# constant term
_o = v.cumsum(-2)
q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)
k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
intra_chunk_attn = q @ k.transpose(-2, -1)
intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
intra_chunk_attn.masked_fill_(
~torch.tril(
torch.ones(chunk_size, chunk_size,
dtype=torch.bool, device=q.device),
), 0)
o = intra_chunk_attn @ v
# quadractic term
kv = torch.einsum(
'b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
kv = kv.cumsum(2)
kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)
# linear term
kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
kv = kv.cumsum(2)
kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)
o = rearrange(o, 'b h n c d -> b h (n c) d')
o = o + _o
return o / (z[..., None] + 1e-6)
if __name__ == "__main__":
B = 4
H = 4
L = 128
# D = 15
dtype = torch.float32
q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True)
do = torch.randn_like(v).cuda()
ref = naive_parallel_based(q, k, v, True, True)
ref.backward(do, retain_graph=True)
ref_dq, q.grad = q.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dv, v.grad = v.grad.clone(), None
# tri = naive_chunk_based(q, k, v)
# tri.backward(do, retain_graph=True)
# tri_dq, q.grad = q.grad.clone(), None
# tri_dk, k.grad = k.grad.clone(), None
# tri_dv, v.grad = v.grad.clone(), None
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
tri = fused_chunk_based(q, k, v, True, True)
tri.backward(do, retain_graph=True)
tri_dq, q.grad = q.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dv, v.grad = v.grad.clone(), None
print((ref-tri).abs().max())
print((ref_dq-tri_dq).abs().max())
print((ref_dk-tri_dk).abs().max())
print((ref_dv-tri_dv).abs().max())
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
tri = parallel_based(q, k, v, True, True)
tri.backward(do, retain_graph=True)
tri_dq, q.grad = q.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dv, v.grad = v.grad.clone(), None
print((ref-tri).abs().max())
print((ref_dq-tri_dq).abs().max())
print((ref_dk-tri_dk).abs().max())
print((ref_dv-tri_dv).abs().max())
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()

View File

@ -0,0 +1,388 @@
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# Based: An Educational and Effective Sequence Mixer
# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
@triton.jit
def parallel_based_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
z, # normalizer [B, H, L]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q
BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
):
# i_c: chunk index. used for sequence parallelism
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
NV = tl.cdiv(DV, BV)
i_k = i_kv // (NV)
i_v = i_kv % (NV)
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
(s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))
# [BQ, BD] block Q, in the shared memory throughout the whole kernel
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
b_o = tl.zeros([BTL, BV], dtype=tl.float32)
b_z = tl.zeros([BTL], dtype=tl.float32)
# Q block and K block have no overlap
# no need for mask, thereby saving flops
for _ in range(0, i_c * BTL, BTS):
# [BK, BTS]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BTS, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
b_s = tl.dot(b_q, (b_k), allow_tf32=False)
b_s = 1 + b_s + 0.5 * b_s * b_s
b_z += tl.sum(b_s, axis=1)
# [BQ, BD]
b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
p_k = tl.advance(p_k, (0, BTS))
p_v = tl.advance(p_v, (BTS, 0))
# # rescale interchunk output
tl.debug_barrier()
o_q = tl.arange(0, BTL)
# # sync threads, easy for compiler to optimize
# tl.debug_barrier()
o_k = tl.arange(0, BTS)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
(s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
# Q block and K block have overlap. masks required
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
# [BK, BTS]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BTS, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
m_s = o_q[:, None] >= o_k[None, :]
b_s = tl.dot(b_q, b_k, allow_tf32=False)
b_s = 1 + b_s + 0.5 * b_s * b_s
b_s = tl.where(m_s, b_s, 0)
b_z += tl.sum(b_s, axis=1)
# [BTL, BV]
b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
p_k = tl.advance(p_k, (0, BTS))
p_v = tl.advance(p_v, (BTS, 0))
o_k += BTS
p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_z, b_z.to(p_z.dtype.element_ty),
mask=((i_c * BTL + tl.arange(0, BTL)) < T))
@triton.jit
def _parallel_based_bwd_dq(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
b_q = (b_q * scale).to(b_q.dtype)
b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
(s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))
p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)
b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)
for _ in range(0, i_c * BTL, BTS):
# [BTS, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BV, BTS]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
if i_v == 0:
b_ds += b_dz[:, None]
else:
b_ds = b_ds
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
# [BQ, BD]
b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False)
p_k = tl.advance(p_k, (BTS, 0))
p_v = tl.advance(p_v, (0, BTS))
b_dq *= scale
o_q = tl.arange(0, BTL)
o_k = tl.arange(0, BTS)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
(s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
# Q block and K block have overlap. masks required
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
# [BTS, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BV, BTS]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
m_s = o_q[:, None] >= o_k[None, :]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
if i_v == 0:
b_ds += b_dz[:, None]
else:
b_ds = b_ds
b_ds = tl.where(m_s, b_ds, 0) * scale
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
# [BTL, BK]
b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype),
b_k, allow_tf32=False)
p_k = tl.advance(p_k, (BTS, 0))
p_v = tl.advance(p_v, (0, BTS))
o_k += BTS
p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
return
@triton.jit
def _parallel_based_bwd_dkv(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
# compute dk dv
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
(i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(
p_v, boundary_check=(0, 1))
b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
[BTL, BV], dtype=tl.float32)
for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS]
b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \
scale # [BTL, BTS]
b_s2 = 1 + b_s + 0.5 * b_s * b_s
b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
if i_v == 0:
b_ds += b_dz[None, :] * scale
else:
b_ds = b_ds
b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype),
tl.trans(b_q), allow_tf32=False)
tl.debug_barrier()
o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
# [BK, BQ]
m_s = o_k[:, None] <= o_q[None, :]
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
b_s2 = 1 + b_s + 0.5 * b_s * b_s
b_s = tl.where(m_s, b_s, 0)
b_s2 = tl.where(m_s, b_s2, 0)
b_ds = tl.dot(b_v, b_do, allow_tf32=False)
if i_v == 0:
b_ds += b_dz[None, :]
else:
b_ds = b_ds
b_ds = tl.where(m_s, b_ds, 0) * scale
# [BK, BD]
b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype),
tl.trans(b_q), allow_tf32=False)
o_q += BTS
p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,
(T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,
(T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
return
@triton.jit
def parallel_based_bwd_kernel(
q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
NV = tl.cdiv(DV, BV)
i_k = i_kv // (NV)
i_v = i_kv % (NV)
i_h = i_bh % H
_parallel_based_bwd_dq(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV
)
tl.debug_barrier()
_parallel_based_bwd_dkv(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV
)
class ParallelBasedFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, scale):
BTL, BTS = 128, 32
assert BTL % BTS == 0
# assert q.shape[-1] % 16 == 0
BK = min(128, triton.next_power_of_2(k.shape[-1]))
BV = min(128, triton.next_power_of_2(v.shape[-1]))
BK, BV = max(BK, 16), max(BV, 16)
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
num_stages = 2
num_warps = 4
NK = triton.cdiv(d_head_qk, BK)
NV = triton.cdiv(d_head_v, BV)
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
assert NK == 1, "will encounter some synchronization issue if not."
o = torch.empty(NK, batch_size, n_heads, seq_len,
d_head_v, device=q.device)
z = torch.empty(NK, batch_size, n_heads, seq_len,
device=q.device)
parallel_based_fwd_kernel[grid](
q, k, v, o, z,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
num_warps=num_warps,
num_stages=num_stages
)
ctx.save_for_backward(q, k, v)
ctx.scale = scale
return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, dz):
q, k, v = ctx.saved_tensors
scale = ctx.scale
BTL, BTS = 64, 32
assert BTL % BTS == 0
BK = min(128, triton.next_power_of_2(k.shape[-1]))
BV = min(128, triton.next_power_of_2(v.shape[-1]))
BK, BV = max(BK, 16), max(BV, 16)
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
num_stages = 2
num_warps = 4
NK = triton.cdiv(d_head_qk, BK)
NV = triton.cdiv(d_head_v, BV)
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
assert NK == 1, "will encounter some synchronization issue if not"
dq = torch.empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=q.dtype, device=q.device)
dk = torch.empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=q.dtype, device=q.device)
dv = torch.empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=q.dtype, device=q.device)
parallel_based_bwd_kernel[grid](
q, k, v, do, dz, dq, dk, dv,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
num_warps=num_warps,
num_stages=num_stages
)
return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
triton_parallel_based = ParallelBasedFunction.apply
def parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=False):
assert q.shape[-1] <= 128, "only support feature dim up to 128"
if use_scale:
scale = q.shape[-1] ** -0.5
else:
scale = 1
o, z = triton_parallel_based(q, k, v, scale)
if return_both:
return o, z
if use_normalize:
o = o / (z[..., None] + 1e-6)
else:
o = o
return o.to(q.dtype)

View File

@ -0,0 +1,4 @@
- Delta Rule
The implementation of delta rule described in https://arxiv.org/abs/2102.11174

View File

@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
from .chunk_fuse import fused_chunk_delta_rule
from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule
from .chunk import chunk_delta_rule
__all__ = [
'fused_chunk_delta_rule',
'fused_recurrent_linear_attn_delta_rule',
'chunk_delta_rule'
]

View File

@ -0,0 +1,544 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
import torch
import triton
import triton.language as tl
from fla.ops.utils import contiguous
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.ops.delta_rule.wy_fast import fwd_recompute_w_u, fwd_prepare_wy_repr, bwd_prepare_wy_repr
from fla.ops.delta_rule.chunk_fuse import fused_chunk_delta_rule_fwd, fused_chunk_delta_rule_bwd
# from fla.ops.delta_rule.utils import bwd_prepare_wy_repr
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_dv_kernel(
q,
k,
do,
dv,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
scale,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_k.dtype)
b_A += tl.dot(b_k, b_q, allow_tf32=False)
b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A , 0).to(do.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_dv = tl.dot(b_A, b_do, allow_tf32=False)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
def fwd_prepare_dv(q, k, do, BT):
dv = torch.empty_like(do)
B, H, T, K, V = *k.shape, do.shape[-1]
NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
fwd_prepare_dv_kernel[(NT, B*H)](
q, k, do, dv,
k.stride(1), k.stride(2), k.stride(3),
do.stride(1), do.stride(2), do.stride(3),
T, K, V, K**-0.5, BT, BK, BV
)
return dv
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_fwd_kernel_h(
k,
v,
d,
v_new,
h,
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
for i_t in range(NT):
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32)
# since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
for i_c in range(tl.cdiv(BT, BC)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BK]
b_d = tl.load(p_d, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)
# [BK, BV]
tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
b_h += b_h_cumsum
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_linear_attn_fwd_kernel_o(
q,
k,
v,
h,
o,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_s = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
b_o += tl.dot(b_q, b_h, allow_tf32=False)
b_s += tl.dot(b_q, b_k, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False))
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_bwd_kernel_dhu(
q,
k,
d,
do,
dh,
dv,
dv2,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_d = tl.load(p_d, boundary_check=(0, 1))
# [BT, V]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT, BT]
# b_s = tl.dot(b_k, b_q, allow_tf32=False)
# b_s = tl.where(m_s, b_s, 0)
# b_dv = tl.dot(b_s.to(b_do.dtype), b_do, allow_tf32=False) + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
p_dv2 = tl.make_block_ptr(dv2 + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# [BK, BV]
b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
b_dh += b_dh_tmp
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_bwd_kernel_dqkw(
q,
k,
v,
w,
h,
do,
dh,
dq,
dk,
dv,
dw,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
n_bh = tl.num_programs(2)
o_i = tl.arange(0, BT)
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dw = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BT]
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dw += tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
# [BT, BT]
b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
# [BT, BK]
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
B, H, T, K, V = *k.shape, u.shape[-1]
BK = triton.next_power_of_2(K)
assert BK <= 256, "current kernel does not support head dimension larger than 256."
BV = 16 if BK > 128 else 32
BV = 64 if BK <= 64 else BV
BC = 16 if BK > 128 else 32
BC = 64 if BK <= 64 else BC
BC = min(BT, BC)
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
h = k.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
v_new = torch.empty_like(u)
chunk_delta_rule_fwd_kernel_h[grid](
k, u, w, v_new, h, initial_state, final_state,
k.stride(1), k.stride(2), k.stride(3),
u.stride(1), u.stride(2), u.stride(3),
h.stride(1), h.stride(2),
H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None,
)
return h, v_new
def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
B, H, T, K, V = *q.shape, do.shape[-1]
BK = triton.next_power_of_2(K)
assert BK <= 256, "current kernel does not support head dimension being larger than 256."
BV = 16 if BK > 128 else 32
BV = 64 if BK <= 64 else BV
BC = 16 if BK > 128 else 32
BC = 64 if BK <= 64 else BC
BC = min(BT, BC)
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
dh = q.new_empty(B, H, NT * K, V)
# dv_new = torch.empty_like(do)
grid = (NK, NV, B * H)
dv2 = torch.empty_like(dv)
chunk_delta_rule_bwd_kernel_dhu[grid](
q, k, w, do, dh, dv, dv2,
q.stride(1), q.stride(2), q.stride(3),
do.stride(1), do.stride(2), do.stride(3),
dh.stride(1), dh.stride(2),
K**-0.5,
H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
)
return dh, dv2
def chunk_fwd_o_fn(q, k, v_new, h, BT):
B, H, T, K, V = *q.shape, v_new.shape[-1]
BK = triton.next_power_of_2(K)
o = torch.empty_like(v_new)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(K), 64)
NV = triton.cdiv(V, BV)
NT = triton.cdiv(T, BT)
grid = (NV, NT, B * H)
chunk_linear_attn_fwd_kernel_o[grid](
q, k, v_new, h, o,
q.stride(1), q.stride(2), q.stride(3),
v_new.stride(1), v_new.stride(2), v_new.stride(3),
h.stride(1), h.stride(2),
scale=K**-0.5,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
)
return o
def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
B, H, T, K, V = *q.shape, v_new.shape[-1]
BK = triton.next_power_of_2(K)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
NV = triton.cdiv(V, BV)
NT = triton.cdiv(T, BT)
grid = (NV, NT, B * H)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dw = torch.empty_like(w)
chunk_delta_rule_bwd_kernel_dqkw[grid](
q, k, v_new, w, h, do, dh, dq, dk, du, dw,
q.stride(1), q.stride(2), q.stride(3),
v_new.stride(1), v_new.stride(2), v_new.stride(3),
dh.stride(1), dh.stride(2),
scale = K ** -0.5,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
)
return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
class ChunkDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
@contiguous
def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):
### obtain WY representation. u is actually the new v.
w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
# ### forward_h
final_state = None
if output_final_state:
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
## obtain output
o = chunk_fwd_o_fn(q, k, v_new, h, BT)
# save memory
if checkpoint_level == 1:
h, v_new = None, None
ctx.save_for_backward(q, k, v, beta, A, h, v_new, initial_state)
ctx.BT = BT
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_ht=None):
q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
scale = q.shape[-1] ** -0.5
BT = ctx.BT
w, u = fwd_recompute_w_u(k, v, beta, A, BT)
# checkpont_level=1, recomputation.
if h is None:
h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
dv = fwd_prepare_dv(q, k, do, BT)
dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)
dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, dv, BT)
dk.add_(dk2)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
def chunk_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
BT: int,
initial_state: torch.Tensor = None,
output_final_state: bool = False
):
assert q.dtype == k.dtype == v.dtype
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
return o, final_state

View File

@ -0,0 +1,419 @@
# -*- coding: utf-8 -*-
from typing import Tuple
import torch
import triton
import triton.language as tl
from packaging import version
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8)
],
key=["BT", "BK"],
)
@triton.jit
def fused_chunk_delta_rule_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_K]
v, # value [B, H, L, D_head_V]
v_new,
d, # decay [B, H, L, D_head_K]
o, # output [B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
# [BT, BT]
m_s = o_i[:, None] >= o_i[None, :]
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
# make block pointers
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
for i in range(0, tl.cdiv(T, BT)):
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_d = tl.load(p_d, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_k.dtype)
# [BT, BT]
b_s = tl.dot(b_q, b_k, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
# [BT, BV]
b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)
b_v = b_v - b_v_prime
tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False)
if CHECK and i == 0:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
else:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
p_q = tl.advance(p_q, (BT, 0))
p_k = tl.advance(p_k, (0, BT))
p_v = tl.advance(p_v, (BT, 0))
p_v_new = tl.advance(p_v_new, (BT, 0))
p_o = tl.advance(p_o, (BT, 0))
p_d = tl.advance(p_d, (BT, 0))
if STORE_FINAL_STATE:
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fused_chunk_delta_rule_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
d, # decay [B, H, L, D_head_K]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
dd, # gradient of decay [NV, B, H, L, D_head_K]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
# first reverse
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
m_s = o_i[:, None] <= o_i[None, :]
for i in range(1, tl.cdiv(T, BT) + 1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
# [DK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, DV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT, BT]
b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
# [BT, BT]
b_s = tl.dot(b_k, b_q, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
# [BT, DK]
b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
# [BT, DV]
b_dv = tl.dot(b_s, b_do, allow_tf32=False)
b_d = tl.load(p_d, boundary_check=(0, 1))
if CHECK and i == 1:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
else:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# sync threads
b_h = None
tl.debug_barrier()
m_s = o_i[:, None] >= o_i[None, :]
# [BV, BK]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
NT = tl.cdiv(T, BT)
for i in range(0, NT):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [DV, BT]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, DV]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT, BT]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
b_ds = tl.where(m_s, b_ds, 0)
# [BT, DK]
b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
# [DV, DK]
if CHECK and i == 0:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
else:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
b_dq *= scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
if i < (NT - 1):
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0))
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
((i+1) * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1))
def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BT = BT
# ctx.BT = BT
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
assert NK == 1, 'NK should be 1'
o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
else:
final_state = None
CHECK = True
# if version.parse(triton.__version__) < version.parse('2.2.0'):
# import warnings
# warnings.warn(
# "Triton<2.2.0 detected for running this kernel, "
# "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
# "that lead to significant precision loss. "
# "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
# "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
# )
# CHECK = True
grid = (NV, NK, batch_size * n_heads)
v_new = torch.empty_like(v)
fused_chunk_delta_rule_fwd_kernel[grid](
q, k, v, v_new, d, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
CHECK=CHECK,
)
return o, v_new, CHECK, final_state
def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
assert NK == 1
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_delta_rule_bwd_kernel[grid](
q, k, v, d, do, dq, dk, dv, dd, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
CHECK=CHECK,
# num_warps=num_warps,
# num_stages=num_stages
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
dd = dd.sum(0)
dd[:, :, 0:BT] = 0
return dq, dk, dv, dd
class FusedChunkDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0):
# lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory.
assert checkpoint_level in [0, 1]
k_origin = k
# k = _l2_norm_fwd(k_origin)
k = k
d, v_new = fwd_prepare_wy_repr(k, v, beta, BT)
o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state)
if checkpoint_level == 1:
d, v_new = None, None
ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state)
ctx.CHECK = CHECK
ctx.chunk_size = BT
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_final_state=None):
q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors
chunk_size = ctx.chunk_size
k = k_origin
# k = _l2_norm_fwd(k_origin)
if d is None:
d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state)
dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size)
dk.add_(dk2)
# dk = _l2_norm_bwd(k_origin, dk)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(d.dtype), None, None, None
def fused_chunk_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
BT: int,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
return o, final_state
def delta_rule_recurrence(q, k, v, beta):
b, h, l, d_k = q.shape
d_v = v.shape[-1]
o = torch.zeros_like(v)
S = torch.zeros(b, h, d_k, d_v).to(v)
q = q * (d_k ** -0.5)
k = torch.nn.functional.normalize(k, p=2, dim=-1)
for i in range(l):
_k = k[:, :, i]
_q = q[:, :, i]
_v = v[:, :, i].clone()
beta_i = beta[:, :, i]
_v = _v - (S.clone() * _k[..., None]).sum(-2)
_v = _v * beta_i[..., None]
S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
return o
if __name__ == "__main__":
import torch.nn.functional as F
seq_len = 128
b = 2
h = 4
q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1)
beta = torch.rand(b, h, seq_len).sigmoid()
q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta))
do = torch.rand_like(v)
o2 = delta_rule_recurrence(q, k, v.clone(), beta)
o2.backward(do, retain_graph=True)
q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad
q.grad = k.grad = v.grad = beta.grad = None
o, _ = fused_chunk_delta_rule(q, k, v, beta, 32)
o.backward(do, retain_graph=True)
q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad
q.grad = k.grad = v.grad = beta.grad = None
print((o - o2).abs().max())
print((q_grad - q_grad2).abs().max())
print((k_grad - k_grad2).abs().max())
print((v_grad - v_grad2).abs().max())
print((beta_grad - beta_grad2).abs().max())

View File

@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
import torch
from einops import rearrange
def delta_rule_recurrence(q, k, v, beta):
b, h, l, d_k = q.shape
d_v = v.shape[-1]
o = torch.zeros_like(v)
S = torch.zeros(b, h, d_k, d_v).to(v)
q = q * (d_k ** -0.5)
for i in range(l):
_k = k[:, :, i]
_q = q[:, :, i]
_v = v[:, :, i].clone()
beta_i = beta[:, :, i]
_v = _v - (S.clone() * _k[..., None]).sum(-2)
_v = _v * beta_i[..., None]
S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
return o
def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
b, h, l, d_k = q.shape
d_v = v.shape[-1]
q = q * (d_k ** -0.5)
v = v * beta[..., None]
k_beta = k * beta[..., None]
assert l % chunk_size == 0
# note that diagonal is masked.
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
for i in range(1, chunk_size):
attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
# u
k_cumsum = attn @ v
# w
k_cumdecay = attn @ k_beta
v = k_cumsum
S = k.new_zeros(b, h, d_k, d_v)
o = torch.zeros_like(v)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
for i in range(0, l // chunk_size):
q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i]
attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
v_prime = k_cumdecay[:, :, i] @ S
v_new = v_i - v_prime
o_inter = q_i @ S
o[:, :, i] = o_inter + attn @ v_new
# chunk state update
S = S + k_i.transpose(-1, -2) @ v_new
return rearrange(o, 'b h n c d -> b h (n c) d')
if __name__ == '__main__':
B = 2
H = 4
L = 256
DK = 128
DV = 128
q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
k = (torch.randn(B, H, L, DK)).cuda()
k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
o = delta_rule_recurrence(q, k, v, beta)
do = torch.randn(B, H, L, DV).cuda()
o.backward(do, retain_graph=True)
q_grad, q.grad = q.grad, None
k_grad, k.grad = k.grad, None
v_grad, v.grad = v.grad, None
beta_grad, beta.grad = beta.grad, None
o2 = delta_rule_chunkwise(q, k, v, beta)
o2.backward(do)
assert torch.allclose(o, o2, atol=1e-4), breakpoint()
assert torch.allclose(q.grad, q_grad, atol=1e-4), breakpoint()
assert torch.allclose(k.grad, k_grad, atol=1e-4), breakpoint()
assert torch.allclose(v.grad, v_grad, atol=1e-4), breakpoint()
assert torch.allclose(beta.grad, beta_grad, atol=1e-4), breakpoint()
print("All passed!")

View File

@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@triton.jit
def fused_recurrent_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V].
beta, # beta [B, H, L]
o, # output [B, H, L, D_head_V]
initial_state,
final_state, # final hidden state [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_beta = beta + i_bh * T
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
mask_kv = mask_bk[None, :] & mask_bv[:, None]
h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_v_minus = tl.sum(h * _k[None, :], axis=1)
_v -= _v_minus
_beta = tl.load(p_beta).to(tl.float32)
# in-place overwrite
tl.store(p_v, _v.to(p_v.dtype.element_ty), mask=mask_bv)
_v *= _beta
h += _k[None, :] * _v[:, None]
_o = h * _q[None, :]
_o = tl.sum(_o, axis=1)
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
p_q += DK
p_k += DK
p_o += DV
p_v += DV
p_beta += 1
if STORE_FINAL_STATE:
p_final_s = final_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
beta, # beta [B, H, L]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
dbeta, # gradient of beta [B, H, L]
# initial hidden state initialization [B, H, D_head_K, D_head_V]
initial_state,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
mask_bk = i_k * BK + tl.arange(0, BK) < DK
mask_bv = i_v * BV + tl.arange(0, BV) < DV
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_beta = beta + i_bh * T + T - 1
p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
BK + tl.arange(0, BK) + (T - 1) * DK
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
BV + tl.arange(0, BV) + (T - 1) * DV
d_h = tl.zeros([BK, BV], dtype=tl.float32)
for _ in range(T):
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_beta = tl.load(p_beta).to(tl.float32)
d_h += _q[:, None] * _do[None, :]
d_k = tl.sum(d_h * _v[None, :] * _beta, axis=1)
d_v = tl.sum(d_h * _k[:, None], axis=0)
d_beta = tl.sum(d_v * _v)
d_v = d_v * _beta
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))
d_h -= _k[:, None] * d_v[None, :]
p_do -= DV
p_q -= DK
p_k -= DK
p_v -= DV
p_dk -= DK
p_dv -= DV
p_dbeta -= 1
p_beta -= 1
tl.debug_barrier()
h = tl.zeros([BK, BV], dtype=tl.float32)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_beta = beta + i_bh * T
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + DV
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + DK
if USE_INITIAL_STATE:
mask_kv = mask_bk[:, None] & mask_bv[None, :]
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[:, None]) * \
DV + (i_v * BV + tl.arange(0, BV)[None, :])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for i in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
_beta = tl.load(p_beta).to(tl.float32)
_v *= _beta
h += _k[:, None] * _v[None, :]
_d_q = h * _do[None, :]
d_q = tl.sum(_d_q, axis=1) * scale
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
if i < T - 1:
d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32)
d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32)
d_k -= tl.sum(d_v[None, :] * h, axis=1)
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
p_k += DK
p_do += DV
p_v += DV
p_dk += DK
p_dv += DV
p_dq += DK
p_beta += 1
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, q, k, v, beta, initial_state=None, output_final_state=False):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 8)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
assert NK == 1, "NK > 1 is not supported yet"
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
else:
final_state = None
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_fwd_kernel[grid](
q, k, v, beta, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, beta, initial_state)
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, d_final_state=None):
q, k, v, beta, initial_state = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 1
num_warps = 2
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
dbeta = q.new_empty(NV, batch_size, n_heads, seq_len)
fused_recurrent_bwd_kernel[grid](
q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
dbeta = dbeta.sum(0)
return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None
def fused_recurrent_linear_attn_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
if beta is None:
beta = torch.ones_like(q[..., 0])
o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, initial_state, output_final_state)
return o, final_state

View File

@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
from einops import rearrange
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
from fla.ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2
# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
# o: cumprod
# o2: cumprodsum
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_wy_repr_kernel(
k,
v,
beta,
o,
o2,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
mask_bt = (tl.arange(0, BT) + i_t * BT) < T
mask_bk = tl.arange(0, BK) < K
mask_bv = tl.arange(0, BV) < V
mask_bk = mask_bk[None, :] & mask_bt[:, None]
mask_bv = mask_bv[None, :] & mask_bt[:, None]
# [BT, BK]
b_k = tl.load(p_k, mask=mask_bk, other=0)
# [BT,]
b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)
# [BT, BV]
b_v = tl.load(p_v, mask=mask_bv, other=0)
b_v = (b_v * b_beta[:, None]).to(b_v.dtype)
# [BT, BK]
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
# [BT, BT]
b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
for i in range(BT):
mask = tl.arange(0, BT) == i
b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
b_A = tl.where(mask[:, None], b_a, b_A)
b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
b_A = b_A.to(b_k.dtype)
b_w = tl.dot(b_A, b_kb, allow_tf32=False)
b_u = tl.dot(b_A, b_v, allow_tf32=False)
p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)
p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def bwd_prepare_wy_repr_kernel(
k, v, beta,
o, o2, do, do2,
dk, dv, dbeta,
NT, K, V, T,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
mask_bt = (tl.arange(0, BT) + i_t * BT) < T
mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]
mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]
b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)
b_beta = b_beta.to(tl.float32)
A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]
A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)
b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)
b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)
dA = tl.zeros([BT, BT], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
for i in range(BT-1, -1, -1):
mask = tl.arange(0, BT) == i
attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)
do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)
dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)
b_do = b_do - attn[:, None] * do_[None, :]
b_dv = b_dv - attn[:, None] * dv_[None, :]
tl.debug_barrier()
p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
b_v = tl.load(p_v, mask=mask_bv)
b_dk += b_do * b_beta[:, None]
b_dbeta = tl.sum(b_do * b_k, axis=1)
b_dbeta += tl.sum(b_dv * b_v, axis=1)
b_v = None
p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
b_o = tl.load(p_o, mask=mask_bk)
b_o2 = tl.load(p_o2, mask=mask_bv)
dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)
dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),
allow_tf32=False)
dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)
b_dv *= b_beta[:, None]
p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)
dA = dA * b_beta[:, None]
b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)
b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)
p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)
tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)
def fwd_prepare_wy_repr(k, v, beta, chunk_size):
B, H, T, K, V = *k.shape, v.shape[-1]
v_new = torch.empty_like(v)
o_cumdecay = torch.empty_like(k)
BT = chunk_size
NT = triton.cdiv(T, BT)
BK = triton.next_power_of_2(K)
BV = triton.next_power_of_2(V)
fwd_prepare_wy_repr_kernel[(NT, B*H)](
k, v, beta, o_cumdecay, v_new,
T, K, V, BT, BK, BV
)
return o_cumdecay, v_new
def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):
b, h, l, d_k = do.shape
d_v = v.shape[-1]
BK = triton.next_power_of_2(d_k)
BV = triton.next_power_of_2(d_v)
c = chunk_size
BK = d_k
NT = triton.cdiv(l, c)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
dbeta = torch.zeros_like(beta)
bwd_prepare_wy_repr_kernel[(NT, b*h)](
k, v, beta,
o_cumdecay, v_new, do, do2,
dk, dv, dbeta,
NT, d_k, d_v, l, chunk_size, BK, BV
)
return dk, dv, dbeta
class WYRepresentationPrepration(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, k, v, beta, chunk_size):
o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
ctx.chunk_size = chunk_size
ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new)
return o_cumdecay, v_new
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, do2):
k, v, beta, o_cumdecay, v_new = ctx.saved_tensors
dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size)
return dk, dv, dbeta, None
prepare_wy_repr = WYRepresentationPrepration.apply
def naive(k, v, beta, chunk_size):
l_org = k.shape[2]
l_new = triton.next_power_of_2(l_org)
# pad k, v, beta
k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
# k = torch.nn.functional.normalize(k, dim=-1, p=2)
beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
k_beta = k * beta[..., None]
v = v * beta[..., None]
attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
attn = attn * beta[..., None]
x = attn @ v
o = torch.zeros_like(k)
o2 = torch.zeros_like(v)
o[..., 0, :] = k_beta[..., 0, :].clone()
o2[..., 0, :] = x[..., 0, :].clone()
for i in range(1, chunk_size):
o_i = (o[..., :i, :]).clone()
o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
o2_i = (o2[..., :i, :]).clone()
o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
if __name__ == "__main__":
torch.set_default_dtype(torch.bfloat16)
seq_len = 2048
b = 4
h = 8
k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2)
v = torch.randn(b, h, seq_len, 256)
beta = torch.rand(b, h, seq_len).sigmoid()
require_grad = True
k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
do = torch.rand_like(k)
do2 = torch.rand_like(v)
print("Start warmup.")
o1, o2 = prepare_wy_repr(k, v, beta, 32)
# (o1 * do + o2 * do2).sum().backward()
o3, o4 = prepare_wy_repr2(k, v, beta, 32)
# (o1 * do + o2 * do2).sum().backward()
print((o1 - o3).abs().max())
print((o2 - o4).abs().max())
for i in range(30):
o1, o2 = prepare_wy_repr(k, v, beta, 32)
(o1 * do + o2 * do2).sum().backward()
o1, o2 = prepare_wy_repr2(k, v, beta, 32)
(o1 * do + o2 * do2).sum().backward()
print("Done warmup.")
import time
torch.cuda.synchronize()
start = time.time()
for i in range(200):
o1, o2 = prepare_wy_repr(k, v, beta, 64)
(o1 * do + o2 * do2).sum().backward()
torch.cuda.synchronize()
print(time.time() - start)
torch.cuda.synchronize()
start = time.time()
for i in range(200):
o1, o2 = prepare_wy_repr2(k, v, beta, 64)
(o1 * do + o2 * do2).sum().backward()
torch.cuda.synchronize()
print(time.time() - start)

View File

@ -0,0 +1,401 @@
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
from einops import rearrange
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
# o: cumprod
# o2: cumprodsum
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_wy_repr_kernel(
k,
v,
beta,
w,
u,
A,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
for i in range(1, BT):
mask = tl.arange(0, BT) == i
b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
b_A = tl.where(mask[:, None], b_a, b_A)
b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))
b_A = b_A.to(k.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
b_w = tl.dot(b_A, b_kb, allow_tf32=False)
p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_recompute_w_u_kernel(
k,
v,
beta,
w,
u,
A,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
b_w = tl.dot(b_A, b_kb, allow_tf32=False)
p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def bwd_prepare_wy_repr_kernel(
k, v, beta, A,
dw, du,
dk, dv, dbeta,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
b_dbeta = tl.zeros([BT], dtype=tl.float32)
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
b_du = tl.load(p_du, boundary_check=(0, 1))
b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)
b_dv = b_dv_beta * b_beta[:, None]
b_dbeta += tl.sum(b_dv_beta * b_v, 1)
# store
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
tl.debug_barrier()
b_A2 = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
b_dw = tl.load(p_dw, boundary_check=(0, 1))
b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
b_A2 += tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)
b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
b_dk = b_dk_beta * b_beta[:, None]
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
# store
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
b_A -= (tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :])
b_A2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_A2, 0)
b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
tl.debug_barrier()
for i in range(BT-1, 0, -1):
mask = tl.arange(0, BT) == i
b_da = tl.sum(tl.where(mask[:, None], b_dA, 0), 0)
b_a = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
b_da2 = b_da + tl.sum(b_da[None, :] * b_A, 1)
b_dA = tl.where(mask[:, None], b_da2, b_dA)
b_dA += b_da[None, :] * b_a[:, None]
b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
tl.debug_barrier()
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dk = tl.load(p_dk, boundary_check=(0, 1))
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
b_dk += b_dk_beta * b_beta[:, None]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty),boundary_check=(0,))
def fwd_prepare_wy_repr(k, v, beta, BT):
B, H, T, K, V = *k.shape, v.shape[-1]
u = torch.empty_like(v)
w = torch.empty_like(k)
NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)
fwd_prepare_wy_repr_kernel[(NT, B*H)](
k, v, beta, w, u, A,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
T, K, V, BT, BK, BV
)
return w, u, A
def fwd_recompute_w_u(k, v, beta, A, BT):
B, H, T, K, V = *k.shape, v.shape[-1]
u = torch.empty_like(v)
w = torch.empty_like(k)
NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
fwd_recompute_w_u_kernel[(NT, B*H)](
k, v, beta, w, u, A,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
T, K, V, BT, BK, BV
)
return w, u
def bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):
B, H, T, K, V = *k.shape, v.shape[-1]
NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
NT = triton.cdiv(T, BT)
dk = torch.empty_like(k)
dv = torch.empty_like(v).contiguous()
dbeta = torch.zeros_like(beta)
bwd_prepare_wy_repr_kernel[(NT, B*H)](
k, v, beta, A,
dw, du,
dk, dv, dbeta,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
T, K, V, BT, BK, BV
)
return dk, dv, dbeta
class WYRepresentationPrepration(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, k, v, beta, chunk_size):
ctx.BT = chunk_size
w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT)
ctx.save_for_backward(k, v, beta, A)
return w, u
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, dw, du):
k, v, beta, A = ctx.saved_tensors
BT = ctx.BT
dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT)
return dk, dv, dbeta, None
prepare_wy_repr = WYRepresentationPrepration.apply
def naive(k, v, beta, chunk_size):
l_org = k.shape[2]
l_new = triton.next_power_of_2(l_org)
# pad k, v, beta
k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
# k = torch.nn.functional.normalize(k, dim=-1, p=2)
beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
k_beta = k * beta[..., None]
v = v * beta[..., None]
attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
attn = attn * beta[..., None]
x = attn @ v
o = torch.zeros_like(k)
o2 = torch.zeros_like(v)
o[..., 0, :] = k_beta[..., 0, :].clone()
o2[..., 0, :] = x[..., 0, :].clone()
for i in range(1, chunk_size):
o_i = (o[..., :i, :]).clone()
o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
o2_i = (o2[..., :i, :]).clone()
o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
if __name__ == "__main__":
torch.set_default_dtype(torch.float32)
seq_len = 1024
b = 4
h = 4
k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)
v = torch.randn(b, h, seq_len, 128)
beta = torch.rand(b, h, seq_len).sigmoid()
# beta = torch.ones(b, h, seq_len)
require_grad = True
k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
do = torch.rand_like(k)
do2 = torch.rand_like(v)
o1, o2 = naive(k.clone(), v.clone(), beta.clone(), 64)
if require_grad:
o1.backward(do, retain_graph=True)
o2.backward(do2, retain_graph=True)
k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad
k.grad = v.grad = beta.grad = None
o3, o4 = prepare_wy_repr(k.clone(), v.clone(), beta.clone())
print((o1-o3).abs().max())
print((o2-o4).abs().max())
if require_grad:
o3.backward(do, retain_graph=True)
o4.backward(do2, retain_graph=True)
k_grad, v_grad, beta_grad = k.grad, v.grad, beta.grad
print((k_grad2-k_grad).abs().max())
print((v_grad2-v_grad).abs().max())
print((beta_grad2-beta_grad).abs().max())
breakpoint()

View File

@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_gla
from .chunk_fuse import fused_chunk_gla
from .recurrent_fuse import fused_recurrent_gla
__all__ = [
'chunk_gla',
'fused_chunk_gla',
'fused_recurrent_gla'
]

734
finetune/lora/v6/fla/ops/gla/chunk.py vendored Normal file
View File

@ -0,0 +1,734 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from fla.ops.utils import chunk_reversed_cumsum_fwd
from fla.utils import contiguous
@triton.autotune(
configs=[
triton.Config({'BS': 16}, num_warps=2),
triton.Config({'BS': 16}, num_warps=4),
triton.Config({'BS': 16}, num_warps=8),
triton.Config({'BS': 32}, num_warps=2),
triton.Config({'BS': 32}, num_warps=4),
triton.Config({'BS': 32}, num_warps=8),
triton.Config({'BS': 64}, num_warps=2),
triton.Config({'BS': 64}, num_warps=4),
triton.Config({'BS': 64}, num_warps=8),
],
key=['S']
)
@triton.jit
def chunk_gla_fwd_kernel_cum(
s,
o,
s_s_h,
s_s_t,
s_s_d,
T: tl.constexpr,
S: tl.constexpr,
BT: tl.constexpr,
BS: tl.constexpr
):
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
# [BT, BS]
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
b_o = tl.dot(m_s, b_s, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_gla_fwd_kernel_h(
k,
v,
g,
h,
h0,
ht,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
for i_t in range(NT):
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BK, BT]
b_g = tl.load(p_g, boundary_check=(0, 1))
if i_t < NT - 1:
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
else:
b_gn = tl.min(b_g, axis=1)
b_h *= tl.exp(b_gn)[:, None]
b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
b_h += tl.dot(b_k, b_v, allow_tf32=False)
if STORE_FINAL_STATE:
p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_gla_fwd_kernel_intra(
q,
k,
g,
A,
s_k_h,
s_k_t,
s_k_d,
scale,
T: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
n_bh = tl.num_programs(2)
if i_i > i_j:
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
# [BC, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))
b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype)
# [BK, BC]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)
# [BC, BC]
b_A = tl.dot(b_qg, b_kg, allow_tf32=False)
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
elif i_i == i_j:
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
# [BC, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))
o_i = tl.arange(0, BC)
o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
for j in range(0, BC):
# [BK,]
b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
# [BC,]
b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1)
b_A = tl.where(o_i >= j, b_A, 0.)
tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)
p_k = tl.advance(p_k, (K,))
p_gk = tl.advance(p_gk, (K,))
@triton.jit
def chunk_gla_fwd_kernel_inter(
q,
v,
g,
h,
o,
A,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
scale,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
b_o = tl.zeros([BT, BV], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BK]
b_g = tl.load(p_g, boundary_check=(0, 1))
# [BT, BK]
b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
# works but dkw, owing to divine benevolence
# [BT, BV]
if i_k >= 0:
b_o += tl.dot(b_qg, b_h, allow_tf32=False)
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BT]
b_A = tl.load(p_A, boundary_check=(0, 1))
b_o += tl.dot(b_A, b_v, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_gla_bwd_kernel_dh(
q,
g,
do,
dh,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
scale,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BV]
b_do = tl.load(p_do, boundary_check=(0, 1))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
# [BK, BV]
b_dh *= tl.exp(b_gn)[:, None]
# [BK, BT]
b_g = tl.load(p_g, boundary_check=(0, 1))
b_q = (b_q * tl.exp(b_g)).to(b_q.dtype)
# [BK, BV]
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
@triton.jit
def chunk_gla_bwd_kernel_inter(
k,
v,
h,
g,
A,
do,
dh,
dq,
dk,
dv,
dA,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
scale,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
n_bh = tl.num_programs(2)
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
# [BT, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk)
b_k = (b_k * b_gn).to(b_k.dtype)
# [BT, BT]
b_A = tl.load(p_A, boundary_check=(0, 1))
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BT, BV]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BV]
b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
if i_k == 0:
b_dv += tl.dot(b_A, b_do, allow_tf32=False)
b_do = (b_do * scale).to(b_do.dtype)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# [BT, BT]
b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False)
# [BT, BK]
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
b_dq = b_dq * tl.exp(b_gk)
b_dk = b_dk * b_gn
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
# [BT, BT]
b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
if i_k == 0:
tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_gla_bwd_kernel_intra(
q,
k,
g,
dA,
dq,
dk,
dg,
s_k_h,
s_k_t,
s_k_d,
T: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_t, i_i = i_c // NC, i_c % NC
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
# [BC, BK]
b_g = tl.load(p_g, boundary_check=(0, 1))
b_dq = tl.zeros([BC, BK], dtype=tl.float32)
for i_j in range(0, i_i):
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
# [BC, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
# [BC, BC]
b_dA = tl.load(p_dA, boundary_check=(0, 1))
# [BC, BK]
b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
b_dq *= tl.exp(b_g - b_gn[None, :])
o_i = tl.arange(0, BC)
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
for j in range(0, BC):
p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
p_gkj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
# [BC,]
b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
# [BK,]
b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
b_gkj = tl.load(p_gkj, boundary_check=(0,)).to(tl.float32)
# [BC, BK]
m_i = o_i[:, None] >= j
# [BC, BK]
b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_g - b_gkj[None, :]), 0.)
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.debug_barrier()
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
# [BC, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_dk = tl.zeros([BC, BK], dtype=tl.float32)
for i_j in range(i_i + 1, NC):
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
# [BC, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))
b_qg = (b_q * tl.exp(b_g - b_gn[None, :])).to(b_q.dtype)
# [BC, BC]
b_dA = tl.load(p_dA, boundary_check=(0, 1))
# [BC, BK]
b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False)
b_dk *= tl.exp(b_gn[None, :] - b_gk)
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
for j in range(0, BC):
p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
p_gqj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
# [BC,]
b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
# [BK,]
b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32)
# [BC, BK]
m_i = o_i[:, None] <= j
b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.)
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_dg = tl.make_block_ptr(dg + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1))
b_dg = b_q * b_dq - b_k * b_dk
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
class ChunkGLAFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level):
B, H, T, K, V = *q.shape, v.shape[-1]
BT, BC = 64, 16
BK = min(64, triton.next_power_of_2(K))
BV = min(64, triton.next_power_of_2(V))
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
num_warps = 4 if BK == 64 else 2
num_stages = 1
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
h = q.new_empty(B, H, NT * K, V)
grid = (NV, NK, B * H)
chunk_gla_fwd_kernel_h[grid](
k, v, g, h, h0, ht,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=h0 is not None,
STORE_FINAL_STATE=ht is not None,
num_warps=num_warps,
num_stages=num_stages
)
return h
final_state = None
if output_final_state:
final_state = q.new_empty(B, H, K, V, dtype=torch.float)
g_org, g = g, torch.empty_like(g, dtype=torch.float)
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
# keep cummulative normalizer in fp32
# this kernel is equivalent to
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
chunk_gla_fwd_kernel_cum[grid](
g_org, g,
g.stride(1), g.stride(2), g.stride(3),
T=T, S=K, BT=BT
)
h = fwd_inner(
q=q, k=k, v=v, g=g,
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
h0=initial_state if initial_state is not None else None,
ht=final_state if final_state is not None else None
)
A = q.new_zeros(NK, B, H, T, BT)
grid = (NK, NT * NC * NC, B * H)
chunk_gla_fwd_kernel_intra[grid](
q, k, g, A,
k.stride(1), k.stride(2), k.stride(3),
scale,
T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
num_warps=num_warps,
num_stages=num_stages
)
A = A.sum(0, dtype=A.dtype)
o = torch.empty_like(v)
grid = (NV, NT, B * H)
chunk_gla_fwd_kernel_inter[grid](
q, v, g, h, o, A,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
scale,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
if checkpoint_level >= 1:
del g
g = g_org
if checkpoint_level > 1:
del h
h, initial_state = None, None
ctx.save_for_backward(q, k, v, g, h, initial_state, A)
ctx.BT = BT
ctx.scale = scale
ctx.checkpoint_level = checkpoint_level
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, dht=None):
q, k, v, g, h, initial_state, A = ctx.saved_tensors
B, H, T, K, V = *q.shape, v.shape[-1]
BT, BC = ctx.BT, 16
BK = min(64, triton.next_power_of_2(K))
BV = min(64, triton.next_power_of_2(V))
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
NK = triton.cdiv(K, BK)
num_warps = 4 if BK == 64 else 2
num_stages = 1
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
h = q.new_empty(B, H, NT * K, V)
grid = (NV, NK, B * H)
chunk_gla_fwd_kernel_h[grid](
k, v, g, h, h0, ht,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=h0 is not None,
STORE_FINAL_STATE=ht is not None,
num_warps=num_warps,
num_stages=num_stages
)
return h
def bwd_inner(q, g, do, B, H, T, K, V, BT, BK, BV, NT, scale):
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
dh = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_gla_bwd_kernel_dh[grid](
q, g, do, dh,
q.stride(1), q.stride(2), q.stride(3),
do.stride(1), do.stride(2), do.stride(3),
dh.stride(1), dh.stride(2), dh.stride(3),
scale,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
return dh
if ctx.checkpoint_level >= 1:
# save the original g and compute its fp32 cumsum during the backward pass for memory consideration
g_org, g = g, torch.zeros_like(g, dtype=torch.float)
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
# keep cummulative normalizer in fp32
# this kernel is equivalent to
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
chunk_gla_fwd_kernel_cum[grid](
g_org, g,
g.stride(1), g.stride(2), g.stride(3),
T=T, S=K, BT=BT
)
# rerun the forward pass to get h if checkpoint_level >= 1
if ctx.checkpoint_level > 1:
h = fwd_inner(
q=q, k=k, v=v, g=g,
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
h0=initial_state if initial_state is not None else None,
ht=None
)
scale = ctx.scale
dh = bwd_inner(
q, g, do,
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
scale=scale
)
dq = torch.empty_like(q, dtype=torch.float)
dk = torch.empty_like(k, dtype=torch.float)
dg = torch.empty_like(k, dtype=torch.float)
dv = v.new_empty(NK, *v.shape)
dA = q.new_zeros(B, H, T, BT)
grid = (NK, NT, B * H)
chunk_gla_bwd_kernel_inter[grid](
k, v, h, g, A, do, dh, dq, dk, dv, dA,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
scale,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
dv = dv.sum(0, dtype=dv.dtype)
grid = (NK, NT * NC, B * H)
chunk_gla_bwd_kernel_intra[grid](
q, k, g, dA, dq, dk, dg,
k.stride(1), k.stride(2), k.stride(3),
T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
num_warps=num_warps,
num_stages=num_stages
)
dq = dq.to(q.dtype)
dk = dk.to(q.dtype)
# reversed cumsum, equivalent to:
#
# def reversed_cumsum(x, dim=-1):
# c = x.cumsum(dim)
# return x + c.index_select(dim, x.new_tensor([c.shape[dim]-1], dtype=torch.long)) - c
dg = chunk_reversed_cumsum_fwd(dg).to(k.dtype)
return dq, dk, dv, dg, None, None, None, None
def chunk_gla(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
scale: Optional[int] = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
checkpoint_level: Optional[int] = 2
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
q (torch.Tensor):
queries of shape `(B, H, T, K)`
k (torch.Tensor):
keys of shape `(B, H, T, K)`
v (torch.Tensor):
values of shape `(B, H, T, V)`
g (torch.Tensor):
Forget gates of shape `(B, H, T, K)` applied to keys.
scale (Optional[int]):
Scale factor for the GLA attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `(B, H, K, V)`. Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
checkpoint_level (Optional[int]):
Checkpointing level; higher values will save more memories and do more recomputations during backward.
Default: `0`:
- Level `0`: no memory saved, no recomputation.
- Level `1`: recompute the fp32 cumulative values during backward.
- Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
"""
assert checkpoint_level in [0, 1, 2]
if scale is None:
scale = q.shape[-1] ** -0.5
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level)
return o, final_state

View File

@ -0,0 +1,548 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Songlin Yang
# Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635
# on-the-fly computation without materializing hidden statets into HBMs
from typing import Tuple
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
from packaging import version
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.ops.gla.chunk_util import (bwd_decay_global_cumsum, fwd_decay_cumsum,
prepare_qg_kg)
from fla.utils import contiguous
inv_ln2 = 1.44269504
@triton.jit
def fused_chunk_gla_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_K]
v, # value [B, H, L, D_head_V]
g, # cumulative sum of log decay [B, H, L, D_head_K]
o, # output [B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
b_h = tl.zeros([BK, BV], dtype=tl.float32)
# make block pointers
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
for i in range(0, tl.cdiv(T, BT)):
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)
if CHECK and i == 0:
b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
else:
b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
p_q = tl.advance(p_q, (BT, 0))
p_k = tl.advance(p_k, (0, BT))
p_v = tl.advance(p_v, (BT, 0))
p_o = tl.advance(p_o, (BT, 0))
p_db += BT * DK
if STORE_FINAL_STATE:
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_chunk_gla_bwd_kernel(
q, k, v, g,
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BV, BK]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
for i in range(0, tl.cdiv(T, BT)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# b_g = tl.load(p_g, boundary_check=(0, 1)) * inv_ln2
d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)
# [DV, BT]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, DV]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [DV, DK]
if CHECK and i == 0:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
else:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
b_dq *= scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
# sync threads
b_h = None
tl.debug_barrier()
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
# cum = tl.zeros([BK], dtype=tl.float32)
for i in range(1, tl.cdiv(T, BT) + 1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
# [DK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, DV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32)
# inter-chunk
# [DK, DV]
if CHECK and i == 1:
b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
else:
b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def fwd_inner_chunk(
q, k, g, A,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
DK: tl.constexpr, # D_head_K
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
o_i = tl.arange(0, BT)
p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
for i in range(BT):
_q = tl.load(p_q, mask=mask, other=0) * scale
gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
s = _q[None, :] * b_k * tl.math.exp2(gq[None, :] - b_g)
score = tl.sum(s, axis=1)
score = tl.where(o_i <= i, score, 0)
tl.store(p_A, score.to(p_A.dtype.element_ty))
p_q += DK
p_gq += DK
p_A += BT
@triton.jit
def bwd_inner_chunk(
q,
k,
g,
dA,
dq,
dk,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
DK: tl.constexpr, # D_head_K
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
o_i = tl.arange(0, BT)
p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
p_dq = dq + (i_bh) * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
for i in range(BT):
_q = tl.load(p_q, mask=mask, other=0)
gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
score = tl.math.exp2(gq[None, :] - b_g)
score = tl.where(o_i[:, None] <= i, score, 0)
_dA = tl.load(p_dA)
_dA = tl.where(o_i <= i, _dA, 0)
b_dk += (_dA[:, None] * score * _q[None, :])
b_dq = tl.sum(_dA[:, None] * score * b_k, axis=0)
tl.store(p_dq, b_dq, mask=mask)
p_q += DK
p_dq += DK
p_gq += DK
p_dA += BT
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1))
class FusedChunkGLAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
ctx.g_dtype = g.dtype
g_original = g
# cumulative decay should be in float32, otherwise the err will be accumulated and amplified.
g = torch.empty_like(g, dtype=torch.float32)
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
ctx.scale = scale
# inter-chunk
BT = 16 # chunk_size
BK, BV = min(d_head_qk, 64), min(d_head_v, 64)
num_stages = 1
num_warps = 2
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
q_g = torch.empty_like(q)
k_g = torch.empty_like(k)
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
fwd_decay_cumsum[grid](
g_original,
g,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
)
prepare_qg_kg[grid](
q, k, g, q_g, k_g,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float, requires_grad=False)
else:
final_state = None
# the bug still exists even for Triton 2.2 on H100 GPUs
# so we always enable initial checks
CHECK = True
if version.parse(triton.__version__) < version.parse('2.2.0'):
import warnings
warnings.warn(
"Triton<2.2.0 detected for running this kernel, "
"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
"that lead to significant precision loss. "
"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
)
CHECK = True
grid = (NV, NK, batch_size * n_heads)
fused_chunk_gla_fwd_kernel[grid](
q_g, k_g, v, g, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
CHECK=CHECK,
num_warps=num_warps,
num_stages=num_stages
)
o = o.sum(0)
# intra-chunk
chunk_size = 16
num_chunk = seq_len // chunk_size
v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
BK = min(d_head_qk, 64)
NK = triton.cdiv(d_head_qk, BK)
A = q.new_empty(NK, batch_size, n_heads, triton.cdiv(seq_len, BT), BT, BT)
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
fwd_inner_chunk[grid](
q, k, g, A,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale, BT=BT, BK=BK, DK=d_head_qk, num_stages=3,
num_warps=4
)
A = A.sum(0)
o2 = A @ v2
o2 = rearrange(o2, 'b h n c d -> b h (n c) d')
# combine inner and inter
o.add_(o2)
ctx.save_for_backward(q, k, v, g_original, A, initial_state)
ctx.CHECK = CHECK
return o.to(v), final_state
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, d_final_state=None):
q, k, v, g_origin, A, initial_state = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = ctx.scale
# recomputation
# inter-chunk
BT = 16 # chunk_size
g = torch.empty_like(g_origin, dtype=torch.float32)
BK, BV = min(d_head_qk, 64), min(d_head_v, 64)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
q_g = torch.empty_like(q)
k_g = torch.empty_like(k)
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
fwd_decay_cumsum[grid](
g_origin,
g,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
)
prepare_qg_kg[grid](
q, k, g, q_g, k_g,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
)
# inter-chunk
BT = 16
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 2
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_gla_bwd_kernel[grid](
q_g, k_g, v, g, do, dq, dk, dv, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
# clamp_min=-3,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
CHECK=ctx.CHECK,
num_warps=num_warps,
num_stages=num_stages,
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
# intra chunk
num_chunk = seq_len // BT
v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk)
dA2 = (do2 @ v2.transpose(-2, -1)) * scale
dv2 = A.transpose(-1, -2) @ do2
dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk)
BK = min(triton.next_power_of_2(d_head_qk), 16)
NK = triton.cdiv(d_head_qk, BK)
dk2 = torch.empty_like(k)
dq2 = torch.empty_like(q)
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
bwd_inner_chunk[grid](
q, k, g,
dA2, dq2, dk2,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, BK=BK,
num_warps=1,
num_stages=3
)
BK = min(triton.next_power_of_2(d_head_qk), 32)
NK = triton.cdiv(d_head_qk, BK)
dg = torch.empty_like(g, dtype=torch.float32)
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
bwd_decay_global_cumsum[grid](
dq2, dq, dk2, dk, q, k, g, dg,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, BK=BK,
num_warps=1,
num_stages=1
)
dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)
def rev_cumsum_exclusive(x):
cumsum_x = x.cumsum(-2)
rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x
return rev_cumsum_x
rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])
dg.add_(rev_cumsum_dg.unsqueeze(-2))
dv.add_(dv2)
dg = rearrange(dg, 'b h n c d -> b h (n c) d')
return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None
def pad(x, chunk_size=16):
seq_len = x.shape[-2]
padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size
if x.shape[-2] % chunk_size != 0:
x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len))
return x
def ceildiv(a, b):
return -(a // -b)
def fused_chunk_gla(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
scale: int = -1,
initial_state: torch.Tensor = None,
output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if scale == -1:
scale = q.shape[-1] ** -0.5
if initial_state is not None:
initial_state = initial_state.detach()
seq_len = q.shape[-2]
q, k, v, g = map(lambda x: pad(x), [q, k, v, g])
o, final_state = FusedChunkGLAFunction.apply(
q, k, v, g, scale, initial_state, output_final_state)
o = o[..., :seq_len, :]
return o, final_state

View File

@ -0,0 +1,138 @@
import triton
import triton.language as tl
inv_ln2 = 1.44269504
@triton.jit
def fwd_decay_cumsum(
g,
g_o,
s_qk_h,
s_qk_t,
s_qk_d,
B,
H,
T,
scale,
BT: tl.constexpr,
BK: tl.constexpr,
DK: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
cum_decay = tl.zeros([BK], dtype=tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
for i in range(BT):
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
cum_decay += _g * inv_ln2
tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)
p_g += DK
p_go += DK
@triton.jit
def prepare_qg_kg(
q,
k,
g,
qg,
kg,
s_qk_h,
s_qk_t,
s_qk_d,
B,
H,
T,
scale,
BT: tl.constexpr,
BK: tl.constexpr,
DK: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
mask = (i_k * BK + tl.arange(0, BK)) < DK
last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))
for i in range(BT):
_q = tl.load(p_q, mask=mask, other=0)
_k = tl.load(p_k, mask=mask, other=0)
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
_q *= tl.math.exp2(_g) * scale
_k *= tl.math.exp2(last_decay - _g)
tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)
tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)
p_q += DK
p_g += DK
p_k += DK
p_kg += DK
p_qg += DK
@triton.jit
def bwd_decay_global_cumsum(
dq_inner,
dq_inter,
dk_inner,
dk_inter,
q, k, g, dg,
s_qk_h,
s_qk_t,
s_qk_d,
B,
H,
T,
scale,
BT: tl.constexpr,
BK: tl.constexpr,
DK: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
cum_grad_dg = tl.zeros([BK], dtype=tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
last_g = tl.zeros([BK], dtype=tl.float32)
for j in range(BT-1, -1, -1):
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
if j == (BT-1):
last_g = _g
_dq1 = tl.load(p_dq_inner, mask=mask, other=0)
_dq2 = tl.load(p_dq_inter, mask=mask, other=0)
_dq2 *= tl.math.exp2(_g)
_dq = _dq1 + _dq2
tl.store(p_dq_inter, _dq, mask=mask)
_dk1 = tl.load(p_dk_inner, mask=mask, other=0)
_dk2 = tl.load(p_dk_inter, mask=mask, other=0)
_dk2 *= tl.math.exp2(last_g - _g)
_dk = _dk1 + _dk2
tl.store(p_dk_inter, _dk, mask=mask)
_q = tl.load(p_q, mask=mask, other=0)
_k = tl.load(p_k, mask=mask, other=0)
_dg = _dq * _q - _dk * _k
cum_grad_dg += _dg
tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)
p_g -= DK
p_k -= DK
p_q -= DK
p_dq_inner -= DK
p_dk_inner -= DK
p_dq_inter -= DK
p_dk_inter -= DK
p_dg -= DK

116
finetune/lora/v6/fla/ops/gla/naive.py vendored Normal file
View File

@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
from fla.ops.gla.recurrent_fuse import fused_recurrent_gla
def ceildiv(a, b):
return -(a // -b)
def naive_recurrent_gla(
q,
k,
v,
gk,
initial_state=None,
output_final_state=False,
causal=True
):
orig_dtype = q.dtype
q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk))
batch_size, n_heads, seq_len, d_head_k = q.shape
_, _, _, d_head_v = v.shape
h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
o = torch.zeros_like(v)
scale = d_head_k ** -0.5
if initial_state is not None:
h += initial_state
for i in range(seq_len):
q_i = q[:, :, i, :] * scale
k_i = k[:, :, i]
v_i = v[:, :, i, :]
gk_i = gk[:, :, i].exp()
kv_i = k_i[..., None] * v_i[..., None, :]
h = h * gk_i[..., None] + kv_i
o_i = (q_i[..., None] * h).sum(-2)
o[:, :, i] = o_i
if causal:
return o.to(orig_dtype), h
else:
o_reverse = torch.zeros_like(v)
h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
for i in range(seq_len-1, -1, -1):
q_i = q[:, :, i, :] * scale
k_i = k[:, :, i]
v_i = v[:, :, i, :]
gk_i = gk[:, :, i].exp()
kv_i = k_i[..., None] * v_i[..., None, :]
h = h * gk_i[..., None] + kv_i
o_i = (q_i[..., None] * h).sum(-2)
o_reverse[:, :, i] = o_i
return o, o_reverse
if __name__ == "__main__":
B = 4
H = 4
L = 512
D = 128
dtype = torch.float32
q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
g = F.logsigmoid(torch.rand(B, H, L, D)).cuda(
).clamp_min(-1).to(torch.float32).requires_grad_(True)
do = torch.rand_like(v).cuda()
do2 = torch.rand_like(v).cuda()
intial_state = torch.rand(B, H, D, D).cuda()
ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False)
ref.backward(do, retain_graph=True)
ref_rev.backward(do2, retain_graph=True)
ref_dq, q.grad = q.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dv, v.grad = v.grad.clone(), None
ref_dg, g.grad = g.grad.clone(), None
tri, tri_rev = fused_recurrent_gla(
q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, causal=False)
tri.backward(do, retain_graph=True)
tri_rev.backward(do2, retain_graph=True)
tri_dq, q.grad = q.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dv, v.grad = v.grad.clone(), None
tri_dg, g.grad = g.grad.clone(), None
assert ref.allclose(tri, 0, 1e-5), breakpoint()
assert ref_rev.allclose(tri_rev, 0, 1e-5), breakpoint()
assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
# tri = fused_chunk_gla(q, k, v, g)
# tri.backward(do, retain_graph=True)
# tri_dq, q.grad = q.grad.clone(), None
# tri_dk, k.grad = k.grad.clone(), None
# tri_dv, v.grad = v.grad.clone(), None
# tri_dg, g.grad = g.grad.clone(), None
# assert ref.allclose(tri, 0, 1e-5), breakpoint()
# assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
# assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
# assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
# assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
# breakpoint()
print("Pass")

View File

@ -0,0 +1,404 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@triton.jit
def fused_recurrent_gla_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_K]
v, # value [B, H, L, D_head_V]
gk, # log gate [B, H, L, D_head_K]
gv, # log gate [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
# initial hidden state initialization [B, H, D_head_K, D_head_V]
initial_state,
final_state, # final hidden state [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
USE_GK: tl.constexpr, # whether to use gk
USE_GV: tl.constexpr, # whether to use gv
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
p_k = k + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
p_v = v + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
h = tl.zeros([BV, BK], dtype=tl.float32)
mask_kv = mask_bk[None, :] & mask_bv[:, None]
if USE_INITIAL_STATE:
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
if USE_GK:
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
h = h * _gk[None, :]
if USE_GV:
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
h = h * _gv[:, None]
h += _k[None, :] * _v[:, None]
_o = h * _q[None, :]
_o = tl.sum(_o, axis=1)
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
p_q += -DK if REVERSE else DK
p_k += -DK if REVERSE else DK
p_o += -DV if REVERSE else DV
p_v += -DV if REVERSE else DV
if USE_GK:
p_gk += -DK if REVERSE else DK
if USE_GV:
p_gv += -DV if REVERSE else DV
if STORE_FINAL_STATE:
p_final_s = final_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_gla_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
gk, # log gate [B, H, L, D_head_K] \alpha
gv, # log gate [B, H, L, D_head_V] \bete
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
# initial hidden state initialization [B, H, D_head_K, D_head_V]
initial_state,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
USE_GK: tl.constexpr, # whether to use gk
USE_GV: tl.constexpr, # whether to use gv
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
p_k = k + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
p_v = v + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
p_do = do + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
mask_bk = i_k * BK + tl.arange(0, BK) < DK
mask_bv = i_v * BV + tl.arange(0, BV) < DV
mask_kv = mask_bk[:, None] & mask_bv[None, :]
h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[:, None]) * \
DV + (i_v * BV + tl.arange(0, BV)[None, :])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for i in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
if USE_GK:
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
h = h * _gk[:, None]
if USE_GV:
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
h = h * _gv[None, :]
h += _k[:, None] * _v[None, :]
_d_q = h * _do[None, :]
d_q = tl.sum(_d_q, axis=1) * scale
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
p_k += -DK if REVERSE else DK
p_v += -DV if REVERSE else DV
p_q += -DK if REVERSE else DK
p_do += -DV if REVERSE else DV
p_dq += -DK if REVERSE else DK
if USE_GK:
p_gk += -DK if REVERSE else DK
if USE_GV:
p_gv += -DV if REVERSE else DV
# sync threads
tl.debug_barrier()
p_q = q + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
p_k = k + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
p_do = do + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
p_v = v + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
d_h = tl.zeros([BK, BV], dtype=tl.float32)
for _ in range(T):
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
d_h += _q[:, None] * _do[None, :]
d_k = tl.sum(d_h * _v[None, :], axis=1)
d_v = tl.sum(d_h * _k[:, None], axis=0)
if USE_GK:
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
d_h *= _gk[:, None]
if USE_GV:
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
d_h *= _gv[None, :]
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
p_do += DV if REVERSE else -DV
p_q += DK if REVERSE else -DK
p_k += DK if REVERSE else -DK
p_v += DV if REVERSE else -DV
p_dk += DK if REVERSE else -DK
p_dv += DV if REVERSE else -DV
if USE_GK:
p_gk += DK if REVERSE else -DK
if USE_GV:
p_gv += DV if REVERSE else -DV
class FusedRecurrentGLAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
# default scale
if scale is None:
scale = d_head_qk ** -0.5
if gk is not None:
gk = gk.float().exp()
if gv is not None:
gv = gv.float().exp()
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
o = q.new_empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=torch.float32)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
else:
final_state = None
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_gla_fwd_kernel[grid](
q, k, v, gk, gv, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None,
USE_GK=gk is not None,
USE_GV=gv is not None,
REVERSE=reverse,
num_warps=num_warps,
num_stages=num_stages
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, gk, gv, initial_state, o)
ctx.scale = scale
ctx.reverse = reverse
# we do not need the gradient of the final state from the next chunk
# similiar to Trunctated BPTT
if final_state is not None:
final_state = final_state.detach()
return o.to(q.dtype), final_state
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, d_final_state=None):
q, k, v, gk, gv, initial_state, o = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = ctx.scale
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
dq = q.new_empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=torch.float32)
dk = q.new_empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=torch.float32)
dv = q.new_empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=torch.float32)
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_gla_bwd_kernel[grid](
q, k, v, gk, gv, do, dq, dk, dv, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None,
REVERSE=ctx.reverse,
USE_GK=gk is not None,
USE_GV=gv is not None
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
if gk is not None:
_dgk = dq * q.float() - dk * k.float()
if ctx.reverse:
dgk = _dgk.cumsum(-2)
else:
_dgk_cumsum = _dgk.cumsum(-2)
dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum
else:
dgk = None
if gv is not None:
_dgv = do.float() * o.float() - dv * v.float()
if ctx.reverse:
dgv = _dgv.cumsum(-2)
else:
_dgv_cumsum = _dgv.cumsum(-2)
dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum
else:
dgv = None
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None
# if scale is None, use d_head_qk ** -0.5 by default. Otherwise specify the scale yourself. e.g. scale = 1.0
def fused_recurrent_gla(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
gk: torch.Tensor = None,
gv: torch.Tensor = None,
scale: int = -1,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
causal: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
if scale == -1:
scale = q.shape[-1] ** -0.5
if initial_state is not None:
initial_state = initial_state.detach()
if causal:
o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state)
return o, final_state
else:
# do not support initial_state yet. looks very strange for bidirectional modeling
assert initial_state is None
assert output_final_state is False
o, final_state = FusedRecurrentGLAFunction.apply(
q, k, v, gk, gv, scale, initial_state, output_final_state, False)
o_reversed, final_state = FusedRecurrentGLAFunction.apply(
q, k, v, gk, gv, scale, initial_state, output_final_state, True)
return [o, o_reversed]

View File

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_hgrn
from .recurrent_fuse import fused_recurrent_hgrn
__all__ = [
'chunk_hgrn',
'fused_recurrent_hgrn'
]

373
finetune/lora/v6/fla/ops/hgrn/chunk.py vendored Normal file
View File

@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Yu Zhang, Songlin Yang
# this function implements the chunkwise form of HGRN, inspired by
# [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html)
# also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan
# from tests on H800, with B, H, D = 16, 4, 128, we see that the chunk can be greatly faster than the recurrent:
#
# Performance:
# seq_len chunk recurrent chunk_bwd recurrent_bwd
# 0 128.0 0.039360 0.061056 0.312160 0.205008
# 1 256.0 0.045824 0.123712 0.308784 0.297696
# 2 512.0 0.058688 0.241952 0.310720 0.626528
# 3 1024.0 0.088288 0.476992 0.313184 1.333152
# 4 2048.0 0.169472 0.943264 0.452464 2.724864
# 5 4096.0 0.329920 1.886144 0.881600 5.551520
# 6 8192.0 0.647872 3.755040 1.740496 11.117184
# 7 16384.0 1.272064 7.520576 3.446608 22.362528
from typing import Tuple
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
@triton.autotune(
configs=[
triton.Config({'BD': 32}, num_warps=1),
triton.Config({'BD': 32}, num_warps=2),
triton.Config({'BD': 32}, num_warps=4),
triton.Config({'BD': 32}, num_warps=8),
triton.Config({'BD': 64}, num_warps=1),
triton.Config({'BD': 64}, num_warps=2),
triton.Config({'BD': 64}, num_warps=4),
triton.Config({'BD': 64}, num_warps=8),
triton.Config({'BD': 128}, num_warps=1),
triton.Config({'BD': 128}, num_warps=2),
triton.Config({'BD': 128}, num_warps=4),
triton.Config({'BD': 128}, num_warps=8),
],
key=['D']
)
@triton.jit
def chunk_hgrn_fwd_kernel_h(
x,
g,
gc,
o,
h0,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr
):
i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
p_x = x + i_bh * T * D + i_t * BT * D + o_d
p_g = g + i_bh * T * D + i_t * BT * D + o_d
p_gc = gc + i_bh * T * D + i_t * BT * D + o_d
p_o = o + i_bh * T * D + i_t * BT * D + o_d
b_h = tl.zeros([BD], dtype=tl.float32)
b_gc = tl.zeros([BD], dtype=tl.float32)
if USE_INITIAL_STATE:
if i_t == 0:
b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
for i in range(0, BT):
mask_t = mask & ((i_t * BT + i) < T)
b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)
b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)
b_h = tl.exp(b_g) * b_h + b_x
b_gc = b_gc + b_g
tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)
tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)
p_x += D
p_g += D
p_gc += D
p_o += D
@triton.jit
def chunk_hgrn_fwd_kernel_o(
gc,
o,
s_h,
s_t,
s_d,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr
):
i_d, i_bh = tl.program_id(0), tl.program_id(1)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
for i_t in range(1, tl.cdiv(T, BT)):
p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
# [BD,]
b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)
# [BT, BD]
b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
b_o = b_o + tl.exp(b_gc) * b_h0[None, :]
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({'BD': 32}, num_warps=1),
triton.Config({'BD': 32}, num_warps=2),
triton.Config({'BD': 32}, num_warps=4),
triton.Config({'BD': 32}, num_warps=8),
triton.Config({'BD': 64}, num_warps=1),
triton.Config({'BD': 64}, num_warps=2),
triton.Config({'BD': 64}, num_warps=4),
triton.Config({'BD': 64}, num_warps=8),
triton.Config({'BD': 128}, num_warps=1),
triton.Config({'BD': 128}, num_warps=2),
triton.Config({'BD': 128}, num_warps=4),
triton.Config({'BD': 128}, num_warps=8),
],
key=['D']
)
@triton.jit
def chunk_hgrn_bwd_kernel_h(
g,
gc,
dx,
do,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr
):
i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
BC = min(BT, T - i_t * BT)
NT = tl.num_programs(1)
p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d
p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d
p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d
p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d
if i_t == NT - 1:
b_gc = tl.zeros([BD], dtype=tl.float32)
else:
b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)
b_dh = tl.zeros([BD], dtype=tl.float32)
for _ in range(BC - 1, -1, -1):
tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)
b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
b_gc = b_gc + b_g
b_dh = b_dh + b_do
b_dx = b_dh
b_dh = b_dh * tl.exp(b_g)
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
p_g -= D
p_gc -= D
p_dx -= D
p_do -= D
@triton.jit
def chunk_hgrn_bwd_kernel_o(
g,
gc,
o,
dx,
dg,
s_h,
s_t,
s_d,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr
):
i_d, i_bh = tl.program_id(0), tl.program_id(1)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))
p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
# [BD,]
mask_t = mask & ((i_t + 1) * BT < T)
b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)
# [BT, BD]
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)
b_dg = tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)
b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]
b_dg = b_o * b_dx * tl.exp(b_g)
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
class ChunkHGRNFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, x, g, initial_state=None, output_final_state=False):
B, H, T, D = x.shape
BT, BD = 128, min(64, triton.next_power_of_2(D))
num_warps = 8 if BD == 64 else 4
gc = torch.empty_like(g, dtype=torch.float)
o = torch.empty_like(x, dtype=torch.float)
def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
chunk_hgrn_fwd_kernel_h[grid](
x, g, gc, o, initial_state,
T, D,
BT=BT,
USE_INITIAL_STATE=initial_state is not None
)
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
chunk_hgrn_fwd_kernel_o[grid](
gc, o,
o.stride(1), o.stride(2), o.stride(3),
T, D,
BT=BT, BD=BD,
num_warps=num_warps
)
final_state = None
if output_final_state:
final_state = o[:, :, -1].clone()
o = o.to(x.dtype)
ctx.save_for_backward(g, o, initial_state)
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, dht=None):
g, o, initial_state = ctx.saved_tensors
B, H, T, D = do.shape
BT, BD = 128, min(64, triton.next_power_of_2(D))
num_warps = 8 if BD == 64 else 4
gc = torch.empty_like(g, dtype=torch.float)
dx = torch.empty_like(o)
dg = torch.empty_like(g)
def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
chunk_hgrn_bwd_kernel_h[grid](
g, gc, dx, do,
T, D,
BT=BT
)
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
chunk_hgrn_bwd_kernel_o[grid](
g, gc, o, dx, dg,
o.stride(1), o.stride(2), o.stride(3),
T, D,
BT=BT, BD=BD,
num_warps=num_warps
)
if initial_state is not None:
dg[:, :, 0] = initial_state * dx[:, :, 0] * g[:, :, 0].exp()
return dx, dg, None, None
def chunk_hgrn(
x: torch.Tensor,
g: torch.Tensor,
initial_state: torch.Tensor = None,
output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)
return o, final_state
if __name__ == '__main__':
import torch.nn.functional as F
from fla.ops.hgrn.naive import naive_recurrent_hgrn
from fla.ops.hgrn.recurrent_fuse import fused_recurrent_hgrn
B, H, T, D = 8, 4, 512, 128
dtype = torch.bfloat16
torch.manual_seed(42)
# [batch_size, n_heads, seq_len, d_head]
x = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
g = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
x, g = (1 - g.sigmoid()) * x, F.logsigmoid(g)
print(f'x:\t{float(x.min()):>10.6f}\t{float(x.max()):>10.6f}')
print(f'g:\t{float(g.min()):>10.6f}\t{float(g.max()):>10.6f}')
x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g))
print(f"DTYPE:\t{x.dtype}")
do = torch.randn_like(x)
h0 = torch.randn_like(x[:, :, 0])
ref, ref_ht = naive_recurrent_hgrn(x, g, h0, output_final_state=True)
ref.backward(do)
ref_dx, x.grad = x.grad.clone(), None
ref_dg, g.grad = g.grad.clone(), None
tri, tri_ht = fused_recurrent_hgrn(x, g, h0, output_final_state=True)
tri.backward(do)
tri_dx, x.grad = x.grad.clone(), None
tri_dg, g.grad = g.grad.clone(), None
print(" \t DIFF\t MAX")
print(' o\t', f"{float((ref - tri).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}")
print('ht\t', f"{float((ref_ht[0] - tri_ht[0]).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}")
print('dx\t', f"{float((ref_dx - tri_dx).abs().max()):>10.6f}\t{float(ref_dx.max()):>10.6f}")
print('dg\t', f"{float((ref_dg - tri_dg).abs().max()):>10.6f}\t{float(ref_dg.max()):>10.6f}")
print('Done!')
@triton.testing.perf_report(
triton.testing.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['seq_len'],
# different possible values for `x_name`
x_vals=[128 * 2 ** i for i in range(0, 8)],
# argument name whose value corresponds to a different line in the plot
line_arg='provider',
# possible values for `line_arg``
line_vals=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
# label name for the lines
line_names=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
# line styles
styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')],
ylabel="Execution Time (ms)", # label name for the y-axis
# name for the plot. Used also as a file name for saving the plot.
plot_name="Performance",
args={},
)
)
def benchmark(seq_len, provider):
dtype = torch.bfloat16
B, H, D = 16, 4, 128
x = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda')
g = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda').sigmoid()
x = (1 - g) * x
x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g))
do = torch.randn_like(x, dtype=dtype)
quantiles = [0.5, 0.2, 0.8]
results = 0, 0, 0
if provider == 'chunk':
results = triton.testing.do_bench(lambda: chunk_hgrn(x, g), quantiles=quantiles)
if provider == 'recurrent':
results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g), quantiles=quantiles)
if provider == 'chunk_bwd':
results = triton.testing.do_bench(lambda: chunk_hgrn(x, g)[0].backward(do), quantiles=quantiles)
if provider == 'recurrent_bwd':
results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g)[0].backward(do), quantiles=quantiles)
return results
benchmark.run(print_data=True)

31
finetune/lora/v6/fla/ops/hgrn/naive.py vendored Normal file
View File

@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
from typing import Optional
import torch
def naive_recurrent_hgrn(
x: torch.Tensor,
g: torch.Tensor,
initial_state: Optional[torch.Tensor] = None,
output_final_state: Optional[bool] = False
) -> torch.Tensor:
dtype = x.dtype
x, g = map(lambda i: i.float(), (x, g))
B, H, T, D = x.shape
h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
o = torch.zeros_like(x)
final_state = None
if initial_state is not None:
h += initial_state.detach()
for i in range(T):
h = g[:, :, i].exp() * h + x[:, :, i]
o[:, :, i] = h
if output_final_state:
final_state = h
return o.to(dtype), final_state

View File

@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
@triton.autotune(
configs=[
triton.Config({'BD': 32}, num_warps=1),
triton.Config({'BD': 32}, num_warps=2),
triton.Config({'BD': 32}, num_warps=4),
triton.Config({'BD': 32}, num_warps=8),
triton.Config({'BD': 64}, num_warps=1),
triton.Config({'BD': 64}, num_warps=2),
triton.Config({'BD': 64}, num_warps=4),
triton.Config({'BD': 64}, num_warps=8),
triton.Config({'BD': 128}, num_warps=1),
triton.Config({'BD': 128}, num_warps=2),
triton.Config({'BD': 128}, num_warps=4),
triton.Config({'BD': 128}, num_warps=8),
],
key=['D']
)
@triton.jit
def fused_recurrent_hgrn_fwd_kernel(
x,
g,
o,
h0,
ht,
T: tl.constexpr,
D: tl.constexpr,
BD: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_d, i_bh = tl.program_id(0), tl.program_id(1)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
p_x = x + i_bh * T * D + o_d
p_g = g + i_bh * T * D + o_d
p_o = o + i_bh * T * D + o_d
b_h = tl.zeros([BD], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = h0 + i_bh * D + o_d
b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
for _ in range(0, T):
b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
b_h = tl.exp(b_g) * b_h + b_x
tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)
p_x += D
p_g += D
p_o += D
if STORE_FINAL_STATE:
p_ht = ht + i_bh * D + o_d
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
@triton.autotune(
configs=[
triton.Config({'BD': 32}, num_warps=1),
triton.Config({'BD': 32}, num_warps=2),
triton.Config({'BD': 32}, num_warps=4),
triton.Config({'BD': 32}, num_warps=8),
triton.Config({'BD': 64}, num_warps=1),
triton.Config({'BD': 64}, num_warps=2),
triton.Config({'BD': 64}, num_warps=4),
triton.Config({'BD': 64}, num_warps=8),
triton.Config({'BD': 128}, num_warps=1),
triton.Config({'BD': 128}, num_warps=2),
triton.Config({'BD': 128}, num_warps=4),
triton.Config({'BD': 128}, num_warps=8),
],
key=['D']
)
@triton.jit
def fused_recurrent_hgrn_bwd_kernel(
g,
o,
dx,
dg,
do,
h0,
T: tl.constexpr,
D: tl.constexpr,
BD: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr
):
i_d, i_bh = tl.program_id(0), tl.program_id(1)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
p_g = g + (i_bh * T + T - 1) * D + o_d
p_o = o + (i_bh * T + T - 2) * D + o_d
p_dx = dx + (i_bh * T + T - 1) * D + o_d
p_dg = dg + (i_bh * T + T - 1) * D + o_d
p_do = do + (i_bh * T + T - 1) * D + o_d
b_dh = tl.zeros([BD], dtype=tl.float32)
for i in range(T - 1, -1, -1):
b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
if i > 0:
b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
elif USE_INITIAL_STATE:
b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
else:
b_o = tl.zeros([BD], dtype=tl.float32)
b_dh = b_dh + b_do
b_dx = b_dh
b_dh = b_dh * tl.exp(b_g)
b_dg = b_dh * b_o
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)
p_g -= D
p_o -= D
p_dx -= D
p_dg -= D
p_do -= D
class FusedRecurrentHGRNFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, x, g, initial_state=None, output_final_state=False):
B, H, T, D = x.shape
final_state = None
if output_final_state:
final_state = x.new_empty(B, H, D)
o = torch.empty_like(x)
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
fused_recurrent_hgrn_fwd_kernel[grid](
x, g, o, initial_state, final_state,
T, D,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None
)
ctx.save_for_backward(g, o, initial_state)
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, dht=None):
g, o, initial_state = ctx.saved_tensors
B, H, T, D = do.shape
dx = torch.empty_like(o)
dg = torch.empty_like(g)
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
fused_recurrent_hgrn_bwd_kernel[grid](
g, o, dx, dg, do, initial_state,
T, D,
USE_INITIAL_STATE=initial_state is not None,
)
return dx, dg, None, None
def fused_recurrent_hgrn(
x: torch.Tensor,
g: torch.Tensor,
initial_state: torch.Tensor = None,
output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state)
return o, final_state

View File

@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_linear_attn
from .chunk_fuse import fused_chunk_linear_attn
from .recurrent_fuse import fused_recurrent_linear_attn
__all__ = [
'chunk_linear_attn',
'fused_chunk_linear_attn',
'fused_recurrent_linear_attn'
]

View File

@ -0,0 +1,359 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
@torch.jit.script
def normalize_output(q, k, o):
k = k.transpose(-2, -1)
k = k.cumsum(-1)
k = k.transpose(-2, -1)
z = (q * k).sum(-1, keepdim=True)
return o / (z + 1e-5)
@triton.jit
def chunk_linear_attn_fwd_kernel_h(
k,
v,
h,
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
for i_t in range(NT):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BK, BV]
b_h += tl.dot(b_k, b_v, allow_tf32=False)
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_linear_attn_fwd_kernel_o(
q,
k,
v,
h,
o,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_s = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
b_o += tl.dot(b_q, b_h, allow_tf32=False)
b_s += tl.dot(b_q, b_k, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_linear_attn_bwd_kernel_dh(
q,
do,
dh,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, V]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BK, BV]
b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
@triton.jit
def chunk_linear_attn_bwd_kernel_dqkv(
q,
k,
v,
h,
do,
dh,
dq,
dk,
dv,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
n_bh = tl.num_programs(2)
o_i = tl.arange(0, BT)
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BT]
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
# [BT, BV]
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# [BT, BT]
b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
# [BT, BK]
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
class ChunkLinearAttentionFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
@contiguous
def forward(ctx, q, k, v, scale, initial_state, output_final_state):
B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
ctx.scale = scale
final_state = None
if output_final_state:
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
h = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_linear_attn_fwd_kernel_h[grid](
k, v, h, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
num_warps=num_warps,
num_stages=num_stages
)
grid = (NV, NT, B * H)
o = torch.empty_like(v)
chunk_linear_attn_fwd_kernel_o[grid](
q, k, v, h, o,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
ctx.save_for_backward(q, k, v, h)
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_ht=None):
q, k, v, h = ctx.saved_tensors
B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(64, triton.next_power_of_2(K)), min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
scale = ctx.scale
dh = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_linear_attn_bwd_kernel_dh[grid](
q, do, dh,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
grid = (NK, NT, B * H)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = v.new_empty(NK, *v.shape)
num_stages = 1
num_warps = 4 if BK == 64 else 2
chunk_linear_attn_bwd_kernel_dqkv[grid](
q, k, v, h, do, dh, dq, dk, dv,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
dv = dv.sum(0)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
def chunk_linear_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float = -1,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
if scale == -1:
scale = q.shape[-1] ** -0.5
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = ChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
if normalize:
o = normalize_output(q * scale, k, o)
return o, final_state

Some files were not shown because too many files have changed in this diff Show More