This commit is contained in:
parent 3488d22d22
commit f05a4acb04

311  finetune/lora/v6/cuda/wkv6infctx_cuda.cu  vendored  Normal file
@@ -0,0 +1,311 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, F *__restrict__ _s,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _u += h*_N_;
    _s += h*_N_*_N_ + i*_N_;

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
    float state[_N_];

    __syncthreads();
    u[i] = float(_u[i]);
    __syncthreads();
    for (int j = 0; j < _N_; j++) {
        state[j] = float(_s[j]);
    }

    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        w[i] = __expf(-__expf(float(_w[t])));
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _s[j] = F(state[j]);
}

template <typename F>
__global__ void kernel_backward_111(const int B, const int T, const int C, const int H,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy,
    F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu, F *__restrict__ const _gs)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _u += h*_N_;
    _s += h*_N_*_N_ + i;

    __shared__ float u_[_N_];
    __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_];
    __syncthreads();
    u_[i] = float(_u[i]);
    __syncthreads();

    const float u = u_[i];

    float state[_N_], scccc[_N_] = {0}, sdddd[_N_] = {0}, sssss[_N_] = {0}, swwww[_N_];
    for (int j = 0; j < _N_; j++) {
        state[j] = float(_s[j*_N_]);
        swwww[j] = 1.0;
    }

    const int t_0 = b*T*C + h*_N_ + i;
    const int t_T_1 = t_0 + (T-1)*C;
    const int t_T = t_0 + T*C;

    float gu = 0;
    for (int t = t_0; t < t_T; t += C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();

        const float k = float(_k[t]);
        const float w = __expf(-__expf(float(_w[t])));
        float gr = 0, gu_ = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = state[j];
            float x = k * v[j];

            gr += (u * x + s) * gy[j];
            gu_ += x * gy[j];
            s = s * w + x;
        }
        _gr[t] = F(gr);
        gu += float(_r[t]) * gu_;
    }
    _gu[b*C + h*_N_ + i] = F(gu);

    for (int t = t_T_1; t >= t_0; t -= C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();

        const float rr = float(_r[t]);
        const float w = __expf(-__expf(float(_w[t])));
        float gk = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            float x = rr * gy[j];

            gk += (u * x + s) * v[j];
            s = x + s * w;
        }
        _gk[t] = F(gk);
    }

    for (int t = t_T_1; t >= t_0; t -= C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        w_[i] = __expf(-__expf(float(_w[t])));
        __syncthreads();

        const float gyy = float(_gy[t]);
        float gv = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = sdddd[j];
            float x = gyy * r[j];

            gv += (u_[j] * x + s) * k[j];
            s = x + s * w_[j];
        }
        _gv[t] = F(gv);
    }

    for (int t = t_0; t < t_T; t += C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        w_[i] = __expf(-__expf(float(_w[t])));
        __syncthreads();

        const float gyy = float(_gy[t]);

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& w = swwww[j];
            sssss[j] += gyy * w * r[j];
            w *= w_[j];
        }
    }
    for (int j = 0; j < _N_; j++)
        _gs[b*H*_N_*_N_ + h*_N_*_N_ + i*_N_ + j] = F(sssss[j]);
}

template <typename F>
__global__ void kernel_backward_222(const int B, const int T, const int C, const int H,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy,
    F *__restrict__ const _gw)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _s += h*_N_*_N_ + i;

    __shared__ float v[_N_], gy[_N_];
    float state[_N_], saaaa[_N_] = {0}, sbbbb[_T_-1] = {0}, scccc[_N_] = {0};
    for (int j = 0; j < _N_; j++) {
        state[j] = float(_s[j*_N_]);
    }

    const int t_0 = b*T*C + h*_N_ + i;
    const int t_1 = t_0 + C;
    const int t_2 = t_0 + 2*C;
    const int t_T_1 = t_0 + (T-1)*C;

    for (int t = t_T_1; t > t_1; t -= C)
    {
        __syncthreads();
        gy[i] = float(_gy[t]);
        v[i] = float(_v[t-2*C]);
        __syncthreads();

        const float r = float(_r[t]);
        const float w = __expf(-__expf(float(_w[t-C])));
        float sum = 0.0f;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = saaaa[j];
            s = (s + r * gy[j]) * w;
            sum += s * v[j];
        }
        sbbbb[(t-t_1)/C] = sum * float(_k[t-2*C]);
    }
    {
        __syncthreads();
        gy[i] = float(_gy[t_1]);
        __syncthreads();

        const float r = float(_r[t_1]);
        const float w = __expf(-__expf(float(_w[t_0])));
        float sum = 0.0f;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = saaaa[j];
            s = (s + r * gy[j]) * w;
            sum += s * state[j];
        }
        sbbbb[0] = sum;
    }

    float sss = sbbbb[0];
    _gw[t_0] = F(sss * -__expf(float(_w[t_0])));

    {
        __syncthreads();
        gy[i] = float(_gy[t_1]);
        __syncthreads();

        const float w = __expf(-__expf(float(_w[t_0])));
        float sum = 0.0f;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            s = (s + state[j]) * w;
            sum += s * gy[j];
        }
        sss += sbbbb[1] - (sum * float(_r[t_1]));
        _gw[t_1] = F(sss * -__expf(float(_w[t_1])));
    }
    for (int t = t_2; t < t_T_1; t += C)
    {
        __syncthreads();
        gy[i] = float(_gy[t]);
        v[i] = float(_v[t-2*C]);
        __syncthreads();

        const float w = __expf(-__expf(float(_w[t-C])));
        const float k = float(_k[t-2*C]);
        float sum = 0.0f;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            s = (s + k * v[j]) * w;
            sum += s * gy[j];
        }
        sss += sbbbb[(t-t_0)/C] - (sum * float(_r[t]));
        _gw[t] = F(sss * -__expf(float(_w[t])));
    }
    _gw[t_T_1] = 0;
}

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *y)
{
    assert(H*_N_ == C);
    assert(_N_%4 == 0);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, y);
}

void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs)
{
    assert(H*_N_ == C);
    assert(_N_%4 == 0);
    kernel_backward_111<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gr, gk, gv, gu, gs);
    kernel_backward_222<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gw);
}

22  finetune/lora/v6/cuda/wkv6infctx_op.cpp  vendored  Normal file
@@ -0,0 +1,22 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y);
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs);

void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) {
    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) {
    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gs.data_ptr<bf16>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv6state forward");
    m.def("backward", &backward, "wkv6state backward");
}

TORCH_LIBRARY(wkv6state, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
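The op file above only registers the bindings; the CUDA sources are compiled with the head size and the context length baked in as the `_N_` and `_T_` macros. Below is a minimal sketch, not part of this commit, of the usual JIT route; `HEAD_SIZE`, `CTX_LEN` and the exact nvcc flags are assumptions and must match the training configuration.

# Sketch only: assumes a CUDA build of PyTorch and the repo layout shown above.
from torch.utils.cpp_extension import load

HEAD_SIZE = 64    # the launchers assert C == H * _N_ and _N_ % 4 == 0
CTX_LEN = 1024    # kernel_backward_222 sizes a local buffer with _T_ - 1

wkv6infctx = load(
    name="wkv6infctx",
    sources=[
        "finetune/lora/v6/cuda/wkv6infctx_op.cpp",
        "finetune/lora/v6/cuda/wkv6infctx_cuda.cu",
    ],
    verbose=True,
    extra_cuda_cflags=["-O3", "--use_fast_math",
                       f"-D_N_={HEAD_SIZE}", f"-D_T_={CTX_LEN}"],
)

# wkv6infctx.forward(B, T, C, H, r, k, v, w, u, s, y) fills y in place and
# writes the updated recurrent state back into s.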

311  finetune/lora/v6/cuda/wkv6state_cuda.cu  vendored  Normal file
@@ -0,0 +1,311 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _u += h*_N_;
    _s += h*_N_*_N_ + i*_N_;

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
    float state[_N_];

    __syncthreads();
    u[i] = float(_u[i]);
    __syncthreads();
    for (int j = 0; j < _N_; j++) {
        state[j] = float(_s[j]);
    }

    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        w[i] = __expf(-__expf(float(_w[t])));
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    // #pragma unroll
    // for (int j = 0; j < _N_; j++)
    //     _s[j] = F(state[j]);
}

template <typename F>
__global__ void kernel_backward_111(const int B, const int T, const int C, const int H,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy,
    F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu, F *__restrict__ const _gs)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _u += h*_N_;
    _s += h*_N_*_N_ + i;

    __shared__ float u_[_N_];
    __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_];
    __syncthreads();
    u_[i] = float(_u[i]);
    __syncthreads();

    const float u = u_[i];

    float state[_N_], scccc[_N_] = {0}, sdddd[_N_] = {0}, sssss[_N_] = {0}, swwww[_N_];
    for (int j = 0; j < _N_; j++) {
        state[j] = float(_s[j*_N_]);
        swwww[j] = 1.0;
    }

    const int t_0 = b*T*C + h*_N_ + i;
    const int t_T_1 = t_0 + (T-1)*C;
    const int t_T = t_0 + T*C;

    float gu = 0;
    for (int t = t_0; t < t_T; t += C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();

        const float k = float(_k[t]);
        const float w = __expf(-__expf(float(_w[t])));
        float gr = 0, gu_ = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = state[j];
            float x = k * v[j];

            gr += (u * x + s) * gy[j];
            gu_ += x * gy[j];
            s = s * w + x;
        }
        _gr[t] = F(gr);
        gu += float(_r[t]) * gu_;
    }
    _gu[b*C + h*_N_ + i] = F(gu);

    for (int t = t_T_1; t >= t_0; t -= C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();

        const float rr = float(_r[t]);
        const float w = __expf(-__expf(float(_w[t])));
        float gk = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            float x = rr * gy[j];

            gk += (u * x + s) * v[j];
            s = x + s * w;
        }
        _gk[t] = F(gk);
    }

    for (int t = t_T_1; t >= t_0; t -= C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        w_[i] = __expf(-__expf(float(_w[t])));
        __syncthreads();

        const float gyy = float(_gy[t]);
        float gv = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = sdddd[j];
            float x = gyy * r[j];

            gv += (u_[j] * x + s) * k[j];
            s = x + s * w_[j];
        }
        _gv[t] = F(gv);
    }

    for (int t = t_0; t < t_T; t += C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        w_[i] = __expf(-__expf(float(_w[t])));
        __syncthreads();

        const float gyy = float(_gy[t]);

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& w = swwww[j];
            sssss[j] += gyy * w * r[j];
            w *= w_[j];
        }
    }
    for (int j = 0; j < _N_; j++)
        _gs[b*H*_N_*_N_ + h*_N_*_N_ + i*_N_ + j] = F(sssss[j]);
}

template <typename F>
__global__ void kernel_backward_222(const int B, const int T, const int C, const int H,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy,
    F *__restrict__ const _gw)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _s += h*_N_*_N_ + i;

    __shared__ float v[_N_], gy[_N_];
    float state[_N_], saaaa[_N_] = {0}, sbbbb[_T_-1] = {0}, scccc[_N_] = {0};
    for (int j = 0; j < _N_; j++) {
        state[j] = float(_s[j*_N_]);
    }

    const int t_0 = b*T*C + h*_N_ + i;
    const int t_1 = t_0 + C;
    const int t_2 = t_0 + 2*C;
    const int t_T_1 = t_0 + (T-1)*C;

    for (int t = t_T_1; t > t_1; t -= C)
    {
        __syncthreads();
        gy[i] = float(_gy[t]);
        v[i] = float(_v[t-2*C]);
        __syncthreads();

        const float r = float(_r[t]);
        const float w = __expf(-__expf(float(_w[t-C])));
        float sum = 0.0f;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = saaaa[j];
            s = (s + r * gy[j]) * w;
            sum += s * v[j];
        }
        sbbbb[(t-t_1)/C] = sum * float(_k[t-2*C]);
    }
    {
        __syncthreads();
        gy[i] = float(_gy[t_1]);
        __syncthreads();

        const float r = float(_r[t_1]);
        const float w = __expf(-__expf(float(_w[t_0])));
        float sum = 0.0f;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = saaaa[j];
            s = (s + r * gy[j]) * w;
            sum += s * state[j];
        }
        sbbbb[0] = sum;
    }

    float sss = sbbbb[0];
    _gw[t_0] = F(sss * -__expf(float(_w[t_0])));

    {
        __syncthreads();
        gy[i] = float(_gy[t_1]);
        __syncthreads();

        const float w = __expf(-__expf(float(_w[t_0])));
        float sum = 0.0f;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            s = (s + state[j]) * w;
            sum += s * gy[j];
        }
        sss += sbbbb[1] - (sum * float(_r[t_1]));
        _gw[t_1] = F(sss * -__expf(float(_w[t_1])));
    }
    for (int t = t_2; t < t_T_1; t += C)
    {
        __syncthreads();
        gy[i] = float(_gy[t]);
        v[i] = float(_v[t-2*C]);
        __syncthreads();

        const float w = __expf(-__expf(float(_w[t-C])));
        const float k = float(_k[t-2*C]);
        float sum = 0.0f;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            s = (s + k * v[j]) * w;
            sum += s * gy[j];
        }
        sss += sbbbb[(t-t_0)/C] - (sum * float(_r[t]));
        _gw[t] = F(sss * -__expf(float(_w[t])));
    }
    _gw[t_T_1] = 0;
}

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *y)
{
    assert(H*_N_ == C);
    assert(_N_%4 == 0);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, y);
}

void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs)
{
    assert(H*_N_ == C);
    assert(_N_%4 == 0);
    kernel_backward_111<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gr, gk, gv, gu, gs);
    kernel_backward_222<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, z, gy, gw);
}

22  finetune/lora/v6/cuda/wkv6state_op.cpp  vendored  Normal file
@@ -0,0 +1,22 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y);
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs);

void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) {
    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) {
    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gs.data_ptr<bf16>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv6state forward");
    m.def("backward", &backward, "wkv6state backward");
}

TORCH_LIBRARY(wkv6state, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
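The state-tuning variant uses the same bindings, but its forward kernel only reads the trainable state (the write-back is commented out in the .cu file above). The shape sketch below is an assumption read off the kernel indexing, not a documented API, and it presumes the extension was JIT-loaded as `wkv6state` in the same way as the earlier sketch.

import torch

# Assumed shapes: r/k/v/w/y are (B, T, C); u is indexed as h*_N_ + i, so (H, N);
# s is indexed as h*_N_*_N_ + i*_N_, so (H, N, N) with no batch dimension; the
# backward's gs buffer does carry a leading batch dimension (b*H*_N_*_N_ + ...).
B, T, HEAD_SIZE, H = 2, 1024, 64, 32
C = H * HEAD_SIZE

dev, dt = "cuda", torch.bfloat16
r = torch.randn(B, T, C, device=dev, dtype=dt).contiguous()
k = torch.randn_like(r); v = torch.randn_like(r); w = torch.randn_like(r)
u = torch.randn(H, HEAD_SIZE, device=dev, dtype=dt).contiguous()
s = torch.zeros(H, HEAD_SIZE, HEAD_SIZE, device=dev, dtype=dt).contiguous()
y = torch.empty_like(r)   # output buffer, filled in place by the kernel

wkv6state.forward(B, T, C, H, r, k, v, w, u, s, y)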

16  finetune/lora/v6/demo/demo-lora-merge.sh  vendored  Normal file
@@ -0,0 +1,16 @@

base_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth'
lora_init='/home/rwkv/JL/out_model/nf4/init_lora.pth'
lora_checkpoint='/home/rwkv/JL/out_model/nf4/rwkv-0.pth'
output='/home/rwkv/JL/model/nf4-world.pth'
QUANT='nf4' # must match the quant setting used for training
TYPE='lora'
Lora_alpha=128

python merge/merge.py --base_model $base_model \
--lora_init $lora_init \
--lora_checkpoint $lora_checkpoint \
--output $output \
--quant $QUANT \
--type $TYPE \
--lora_alpha $Lora_alpha

27  finetune/lora/v6/demo/demo-lora.sh  vendored  Normal file
@@ -0,0 +1,27 @@
load_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth'
proj_dir='/home/rwkv/JL/out_model/nf4'
data_file='/home/rwkv/JL/data/roleplay'

QUANT='nf4'  # options: 4bit nf4 fp4 none

lora_r=64
lora_alpha=128

n_layer=32
n_embd=4096

micro_bsz=8
epoch_save=1
epoch_steps=1000
ctx_len=1024

python train.py --load_model $load_model \
--proj_dir $proj_dir --data_file $data_file \
--data_type binidx --vocab_size 65536 \
--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
--n_layer $n_layer --n_embd $n_embd \
--pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
--my_testing "x060" \
--lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha $lora_alpha --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
--quant $QUANT

15  finetune/lora/v6/demo/demo-pissa-merge.sh  vendored  Normal file
@@ -0,0 +1,15 @@


base_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2-20240208-ctx4096.pth'
lora_init='/home/rwkv/JL/out_model/nf4/init_lora.pth'
lora_checkpoint='/home/rwkv/JL/out_model/nf4/rwkv-0.pth'
output='/home/rwkv/JL/model/end-world.pth'
QUANT='nf4' # must match the quant setting used for training
TYPE='pissa'

python merge/merge.py --base_model $base_model \
--lora_init $lora_init \
--lora_checkpoint $lora_checkpoint \
--output $output \
--quant $QUANT \
--type $TYPE

40  finetune/lora/v6/demo/demo-pissa.sh  vendored  Normal file
@@ -0,0 +1,40 @@

load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth'
proj_dir='/home/rwkv/JL/out_model/nf4'
data_file='/home/rwkv/JL/data/end_text_document'

QUANT='nf4'  # options: 4bit nf4 fp4 none
svd_niter=4
lora_r=64

n_layer=24
n_embd=2048

micro_bsz=8
epoch_save=1
epoch_steps=1000
ctx_len=1024

python train.py --load_model $load_model \
--proj_dir $proj_dir --data_file $data_file \
--data_type binidx --vocab_size 65536 \
--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
--n_layer $n_layer --n_embd $n_embd \
--pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
--my_testing "x060" \
--lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
--PISSA --svd_niter $svd_niter \
--dataload pad

### remove load_model:
# python train.py --proj_dir $proj_dir --data_file $data_file \
# --data_type binidx --vocab_size 65536 \
# --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
# --n_layer $n_layer --n_embd $n_embd \
# --pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
# --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
# --my_testing "x060" \
# --lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
# --PISSA --svd_niter $svd_niter \
# --quant $QUANT

27  finetune/lora/v6/demo/demo-qpissa-pt.sh  vendored  Normal file
@@ -0,0 +1,27 @@
load_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth'
proj_dir='/home/rwkv/JL/out_model/nf4'
data_file='/home/rwkv/JL/data/roleplay'

QUANT='nf4'  # options: 4bit nf4 fp4 none
svd_niter=4
lora_r=64

n_layer=32
n_embd=4096

micro_bsz=4
epoch_save=1
epoch_steps=1000
ctx_len=1024


python train.py --proj_dir $proj_dir --data_file $data_file \
--data_type binidx --vocab_size 65536 \
--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
--n_layer $n_layer --n_embd $n_embd \
--pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
--my_testing "x060" \
--lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
--PISSA --svd_niter $svd_niter \
--quant $QUANT

8  finetune/lora/v6/demo/demo-state-merge.sh  vendored  Normal file
@@ -0,0 +1,8 @@
base_model='/home/rwkv/JL/model/RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth'
state_checkpoint='/home/rwkv/JL/out_model/state/rwkv-9.pth'
output='/home/rwkv/JL/model/state-0.pth'


python merge/merge_state.py --base_model $base_model \
--state_checkpoint $state_checkpoint \
--output $output

22  finetune/lora/v6/demo/demo-state-tuning.sh  vendored  Normal file
@@ -0,0 +1,22 @@
load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth'
proj_dir='/home/rwkv/JL/out_model/state'
data_file='/home/rwkv/JL/data/end_text_document'


n_layer=24
n_embd=2048

micro_bsz=1
epoch_save=1
epoch_steps=1000
ctx_len=1024

python train.py --load_model $load_model \
--proj_dir $proj_dir --data_file $data_file \
--data_type binidx --vocab_size 65536 \
--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
--n_layer $n_layer --n_embd $n_embd \
--pre_ffn 0 --head_qk 0 --lr_init 1 --lr_final 1e-1 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 0 \
--my_testing "x060" \
--train_type "state" --dataload pad --wandb fla --fla

27  finetune/lora/v6/demo/demo-training-prepare.sh  vendored  Normal file
@@ -0,0 +1,27 @@
#!/bin/bash

# Create data directory

mkdir -p data

# Download minipile (1498226207 tokens, around 3GB)

wget --continue -O data/minipile.idx https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.idx
wget --continue -O data/minipile.bin https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.bin

# Generate initial model (L12-D768 = 169M)

BASE_NAME="model/0.1-1"
N_LAYER="12"
N_EMBD="768"

# magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case)
# use https://www.dcode.fr/prime-numbers-search

python train.py --wandb "" --proj_dir $BASE_NAME \
 --data_file "data/minipile" --data_type "binidx" --vocab_size 65536 \
 --ctx_len 512 --my_pile_stage 1 --epoch_count 1 --epoch_begin 0 \
 --epoch_save 1 --weight_decay 0 --head_size_a 64 \
 --num_nodes 1 --micro_bsz 1 --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 --my_exit_tokens 1498226207 --magic_prime 2926181 \
 --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 \
 --accelerator cpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 --enable_progress_bar False --ds_bucket_mb 200
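The magic_prime comment in the script above states a concrete rule: the largest prime p with p congruent to 2 mod 3 ("3n+2") that is smaller than datalen/ctxlen - 1. A small stand-alone sketch of that rule follows; the helper names are illustrative and not part of the repo.

# Helper names (is_prime, largest_magic_prime) are illustrative, not from the repo.
def is_prime(n: int) -> bool:
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    d = 3
    while d * d <= n:
        if n % d == 0:
            return False
        d += 2
    return True

def largest_magic_prime(datalen: int, ctxlen: int) -> int:
    p = datalen // ctxlen - 1              # start just below datalen/ctxlen - 1
    while not (p % 3 == 2 and is_prime(p)):
        p -= 1
    return p

print(largest_magic_prime(1498226207, 512))   # the script above passes 2926181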

21  finetune/lora/v6/demo/demo-training-run.sh  vendored  Normal file
@@ -0,0 +1,21 @@
#!/bin/bash

BASE_NAME="model/0.1-1"
N_LAYER="12"
N_EMBD="768"
M_BSZ="16" # takes 16G VRAM (reduce this to save VRAM)
LR_INIT="6e-4"
LR_FINAL="6e-5"
GRAD_CP=0 # set to 1 to save VRAM (will be slower)
EPOCH_SAVE=10

# magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case)
# use https://www.dcode.fr/prime-numbers-search

python train.py --load_model "0" --wandb "RWKV-5-Test" --proj_dir $BASE_NAME \
 --ctx_len 512 --my_pile_stage 3 --epoch_count 999999 --epoch_begin 0 \
 --data_file "data/minipile" --my_exit_tokens 1498226207 --magic_prime 2926181 \
 --num_nodes 1 --micro_bsz $M_BSZ --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 \
 --lr_init $LR_INIT --lr_final $LR_FINAL --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 --data_type "binidx" --vocab_size 65536 \
 --weight_decay 0.001 --epoch_save $EPOCH_SAVE --head_size_a 64 \
 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp $GRAD_CP --enable_progress_bar True --ds_bucket_mb 200

182  finetune/lora/v6/demo/demo.jsonl  vendored  Normal file
File diff suppressed because one or more lines are too long

25  finetune/lora/v6/demo/infctx.sh  vendored  Normal file
@@ -0,0 +1,25 @@
load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth'
proj_dir='/home/rwkv/JL/out_model/infctx'
data_file='/home/rwkv/JL/data/roleplay'


n_layer=24
n_embd=2048

micro_bsz=8
epoch_save=5
epoch_steps=1000
ctx_len=16384
chunk_ctx=2048


python train.py --load_model $load_model \
--proj_dir $proj_dir --data_file $data_file \
--data_type binidx --vocab_size 65536 \
--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
--n_layer $n_layer --n_embd $n_embd \
--pre_ffn 0 --head_qk 0 --lr_init 1e-4 --lr_final 1e-4 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
--lora_load rwkv-0 --lora --lora_r 64 --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
--my_testing "x060" --dataload pad \
--train_type infctx --chunk_ctx $chunk_ctx --fla --wandb infctx

50  finetune/lora/v6/fla/__init__.py  vendored  Normal file
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-

from fla.layers import (ABCAttention, BasedLinearAttention, DeltaNet,
                        GatedLinearAttention, HGRN2Attention, LinearAttention,
                        MultiScaleRetention, ReBasedLinearAttention)
from fla.models import (ABCForCausalLM, ABCModel, DeltaNetForCausalLM,
                        DeltaNetModel, GLAForCausalLM, GLAModel,
                        HGRN2ForCausalLM, HGRN2Model, HGRNForCausalLM,
                        HGRNModel, LinearAttentionForCausalLM,
                        LinearAttentionModel, RetNetForCausalLM, RetNetModel,
                        RWKV6ForCausalLM, RWKV6Model, TransformerForCausalLM,
                        TransformerModel)
from fla.ops import (chunk_gla, chunk_retention, fused_chunk_based,
                     fused_chunk_gla, fused_chunk_retention)

__all__ = [
    'ABCAttention',
    'BasedLinearAttention',
    'DeltaNet',
    'HGRN2Attention',
    'GatedLinearAttention',
    'LinearAttention',
    'MultiScaleRetention',
    'ReBasedLinearAttention',
    'ABCForCausalLM',
    'ABCModel',
    'DeltaNetForCausalLM',
    'DeltaNetModel',
    'HGRNForCausalLM',
    'HGRNModel',
    'HGRN2ForCausalLM',
    'HGRN2Model',
    'GLAForCausalLM',
    'GLAModel',
    'LinearAttentionForCausalLM',
    'LinearAttentionModel',
    'RetNetForCausalLM',
    'RetNetModel',
    'RWKV6ForCausalLM',
    'RWKV6Model',
    'TransformerForCausalLM',
    'TransformerModel',
    'chunk_gla',
    'chunk_retention',
    'fused_chunk_based',
    'fused_chunk_gla',
    'fused_chunk_retention'
]

__version__ = '0.1'

25  finetune/lora/v6/fla/layers/__init__.py  vendored  Normal file
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-

from .abc import ABCAttention
from .based import BasedLinearAttention
from .delta_net import DeltaNet
from .gla import GatedLinearAttention
from .hgrn import HGRNAttention
from .hgrn2 import HGRN2Attention
from .linear_attn import LinearAttention
from .multiscale_retention import MultiScaleRetention
from .rebased import ReBasedLinearAttention
from .rwkv6 import RWKV6Attention

__all__ = [
    'ABCAttention',
    'BasedLinearAttention',
    'DeltaNet',
    'GatedLinearAttention',
    'HGRNAttention',
    'HGRN2Attention',
    'LinearAttention',
    'MultiScaleRetention',
    'ReBasedLinearAttention',
    'RWKV6Attention'
]
 | 
			
		||||
							
								
								
									
										195
									
								
								finetune/lora/v6/fla/layers/abc.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										195
									
								
								finetune/lora/v6/fla/layers/abc.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,195 @@
 | 
			
		||||
# -*- coding: utf-8 -*-

from __future__ import annotations

import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from transformers.cache_utils import Cache

from fla.modules import (FusedRMSNormSwishGate, RMSNorm, RotaryEmbedding,
                         ShortConvolution)
from fla.modules.activations import swiglu, swish
from fla.modules.convolution import proj_then_conv1d
from fla.ops.abc.chunk import chunk_abc


class ABCAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        share_conv_kernel: bool = True,
        num_slots: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_low_rank_dim: int = 16,
        gate_logit_normalizer: int = 16,
        use_input_gate: bool = False,
        use_output_gate: bool = True,
        use_norm: bool = True,
        use_rope: bool = False,  # `self.use_rope` is referenced below; an assumed default of False is added here
        clamp_min: Optional[float] = -32,
        clamp_max: Optional[float] = 32,
        layer_idx: Optional[int] = None,
        **kwargs
    ) -> ABCAttention:
        super().__init__()

        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.key_dim = int(self.hidden_size * self.expand_k)
        self.value_dim = int(self.hidden_size * self.expand_v)
        self.head_k_dim = self.key_dim // self.num_heads
        self.head_v_dim = self.value_dim // self.num_heads

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.share_conv_kernel = share_conv_kernel

        self.gate_low_rank_dim = gate_low_rank_dim
        self.gate_logit_normalizer = gate_logit_normalizer

        self.use_input_gate = use_input_gate
        self.use_output_gate = use_output_gate
        self.use_norm = use_norm
        self.use_rope = use_rope

        if num_slots is None:
            num_slots = self.head_k_dim
        self.num_slots = num_slots

        self.norm_eps = norm_eps

        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.layer_idx = layer_idx

        if layer_idx is None:
            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)

        if use_output_gate:
            self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
        self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            if share_conv_kernel:
                self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
            else:
                self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
                self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
                self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')

        if self.use_norm:
            if self.use_output_gate:
                self.g_norm = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
            else:
                self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)

        if self.use_rope:
            self.rotary = RotaryEmbedding(self.head_k_dim)

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:

        if self.use_short_conv:
            if self.share_conv_kernel:
                hidden_states = self.h_conv1d(hidden_states)
                q = self.q_proj(hidden_states)
                k = self.k_proj(hidden_states)
                v = self.v_proj(hidden_states)
            else:
                q = proj_then_conv1d(hidden_states, self.q_proj.weight, self.q_conv1d.weight, self.q_conv1d.bias)
                k = proj_then_conv1d(hidden_states, self.k_proj.weight, self.k_conv1d.weight, self.k_conv1d.bias)
                v = proj_then_conv1d(hidden_states, self.v_proj.weight, self.v_conv1d.weight, self.v_conv1d.bias)
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)

        if self.use_input_gate:
            q, k, v = map(lambda x: swish(x), (q, k, v))

        if self.use_rope:
            q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
            k = rearrange(k, '... (h d) -> ... h d', h=self.num_heads)
            seqlen_offset = 0
            if past_key_values is not None:
                seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            q, k = self.rotary(q, k, seqlen_offset)
            q = rearrange(q, 'b n h d -> b h n d', h=self.num_heads)
            k = rearrange(k, 'b n h d -> b h n d', h=self.num_heads)
        else:
            q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
            k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)

        # [batch_size, n_heads, seq_len, num_slots]
        s = rearrange(self.s_proj(hidden_states), 'b t (h m) -> b h t m', h=self.num_heads)
        s = s.clamp_(self.clamp_min, self.clamp_max)

        last_state = past_key_values[self.layer_idx] if use_cache else None
        o, last_state = chunk_abc(q, k, v, s, initial_state=last_state, output_final_state=use_cache)
        if past_key_values is not None and last_state is not None:
            past_key_values.update(last_state, self.layer_idx, q.shape[2])

        o = rearrange(o, 'b h t d -> b t h d')
        if self.use_norm and not self.use_output_gate:
            o = self.g_norm(o)
        elif self.use_output_gate:
            g = rearrange(self.g_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_heads)
            o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values

    def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
        param = next(self.parameters())
        state = tuple()
        if self.use_short_conv:
            state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
        state += (param.new_zeros(batch_size, self.num_heads, self.head_k_dim, self.num_slots),
                  param.new_zeros(batch_size, self.num_heads, self.num_slots, self.head_v_dim))
        return state

    def state_size(self, sequence_length: int = 2048):
        return self.num_heads * self.key_dim * self.head_v_dim
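
Like the other layers in this package, `ABCAttention.forward` returns an `(output, attentions, past_key_values)` triple with `attentions` always `None`. A minimal instantiation sketch, assuming a CUDA device and the Triton `chunk_abc` kernel are available (the sizes are illustrative):

    import torch
    from fla.layers.abc import ABCAttention

    attn = ABCAttention(hidden_size=512, num_heads=4, layer_idx=0).cuda()
    x = torch.randn(2, 64, 512, device='cuda')
    o, _, _ = attn(x)                        # no cache passed, so nothing is updated
    state = attn.init_state(batch_size=2)    # zero slot-key / slot-value states
    print(o.shape, [s.shape for s in state])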
							
								
								
									
finetune/lora/v6/fla/layers/based.py (vendored, new file, 126 lines)
@@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-

"""
Linear attention in Based.
https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
"""

import torch
import torch.nn as nn
from einops import rearrange

from fla.modules.feature_map import TaylorFeatureMap
from fla.ops.based import parallel_based
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn


class BasedLinearAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        l_max: int = 2048,
        feature_dim: int = 16,
        num_key_value_heads: int = 12,
        num_heads: int = 12,
        feature_name: str = "taylor_exp",
        eps: float = 1e-12,
        causal: bool = True,
        mode: str = "parallel",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.l_max = l_max
        self.mode = mode
        assert self.mode in ["fused_chunk", "parallel", 'chunk']

        # linear attention
        self.feature_name = feature_name
        self.feature_dim = feature_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_key_value_heads
        self.causal = causal

        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.dropout = nn.Identity()
        self.feature_map = TaylorFeatureMap(feature_dim)
        self.eps = eps

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        mode = self.mode
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v])
        if mode == "fused_chunk":
            q, k = self.feature_map(q), self.feature_map(k)
            o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)
        elif mode == 'chunk':
            q, k = self.feature_map(q), self.feature_map(k)
            o = chunk_linear_attn(q, k, v, normalize=True, scale=1)
        elif mode == 'parallel':
            assert q.shape[-1] <= 128
            o = parallel_based(q, k, v, True, True)
        o = rearrange(o, "b h l d -> b l (h d)")
        o = self.o_proj(o)
        o = self.dropout(o)
        return o

    # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119

    def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
        """
        x (torch.Tensor): tensor of shape (b, d, l)
        y (torch.Tensor): tensor of shape (b, d, l)
        """
        # hidden_states = hidden_states.transpose(1, 2)
        b, l, _ = hidden_states.size()
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2)
        k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
        v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Linear attention
        q, k = self.feature_map(q), self.feature_map(k)
        q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)

        # Compute attention
        if self.causal:
            y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
        else:
            y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
        y = rearrange(y, 'b h l d -> b l (h d)')
        y = self.o_proj(y.to(hidden_states.dtype))
        y = self.dropout(y)
        return y.to(hidden_states.dtype)


if __name__ == '__main__':
    batch = 4
    seq_len = 1024
    hidden_size = 1024
    dtype = torch.float32
    x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True)
    dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda()
    model = BasedLinearAttention(hidden_size, mode='chunk').to(dtype).cuda()
    y = model(x)
    y.backward(dy, retain_graph=True)
    x_grad, x.grad = x.grad, None
    y2 = model.forward_reference(x)
    y2.backward(dy)
    assert y.allclose(y2, 0, 1e-4), breakpoint()
    assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint()
    print("Pass")
							
								
								
									
finetune/lora/v6/fla/layers/delta_net.py (vendored, new file, 254 lines)
@@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-

# Sect4.2 of Linear Transformers Are Secretly Fast Weight Programmers https://arxiv.org/abs/2102.11174


from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from transformers.cache_utils import Cache

from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution, LayerNorm
from fla.modules.rotary import RotaryEmbedding
from fla.ops.delta_rule import (fused_chunk_delta_rule,
                                fused_recurrent_linear_attn_delta_rule,
                                chunk_delta_rule)
from torch.nn import functional as F


def simple_norm(x):
    return (F.normalize(x, dim=-1) * x.shape[-1] ** 0.5).to(x)


# @torch.jit.script
def elu_p1(x):
    return (F.elu(x, 1., False) + 1.).to(x)


# @torch.jit.script
def sum_norm(x):
    return (x / x.sum(-1, keepdim=True)).to(x)


# @torch.jit.script
def elu_norm(x):
    dtype = x.dtype
    x = F.elu(x, 1., False) + 1.
    return (x / x.sum(-1, keepdim=True)).to(dtype)


# https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1
class DeltaNet(nn.Module):
    def __init__(
        self,
        d_model: int = None,
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
        num_heads: int = 4,
        mode: str = 'fused_chunk',
        chunk_size: int = 16,
        use_beta: bool = True,
        use_gate: bool = True,
        use_rope: bool = False,
        use_output_norm: bool = True,
        use_elu: bool = False,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        share_conv_kernel: bool = False,
        layer_idx: int = None,
        qk_activation: str = 'silu',
        qk_norm: str = 'l2',  # must be 'l2' or 'sum' (see assert below); 'l2' assumed as the default
        save_memory: bool = False,
        **kwargs
    ) -> DeltaNet:
        super().__init__()
        self.mode = mode
        self.qk_activation = qk_activation
        self.qk_norm = qk_norm
        assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
        assert self.qk_norm in ['l2', 'sum']
        if d_model is not None:
            hidden_size = d_model
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.chunk_size = chunk_size
        self.use_gate = use_gate
        self.use_output_norm = use_output_norm
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.share_conv_kernel = share_conv_kernel

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.layer_idx = layer_idx

        self.silu = torch.nn.SiLU()

        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        self.use_beta = use_beta
        self.use_elu = use_elu
        if self.use_beta:
            self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
        if use_short_conv:
            self.conv_size = conv_size
            if share_conv_kernel:
                self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation=None)
            else:
                self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
                self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
                self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        if self.use_gate:
            self.norm = FusedRMSNormSwishGate(self.head_v_dim)
        else:
            self.norm = RMSNorm(self.head_v_dim)
        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True


    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:

        # change to inference mode.
        mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode
        last_state = past_key_values[self.layer_idx] if use_cache else None

        if attention_mask is not None:
            if attention_mask.shape[-1] != hidden_states.shape[-2]:
                attention_mask = attention_mask[:, -1:]

        if self.use_short_conv:
            conv_state = last_state[0] if use_cache else None
            if self.share_conv_kernel:
                # conv state is updated inplace
                hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
                q = self.q_proj(hidden_states)
                k = self.k_proj(hidden_states)
                v = self.v_proj(hidden_states)
            else:
                conv_state_q = last_state[0] if use_cache else None
                conv_state_k = last_state[1] if use_cache else None
                conv_state_v = last_state[2] if use_cache else None
                k = self.k_proj(hidden_states)
                v = self.v_proj(hidden_states)
                q = self.q_proj(hidden_states)
                q = self.q_conv1d(q, attention_mask, conv_state_q)
                k = self.k_conv1d(k, attention_mask, conv_state_k)
                v = self.v_conv1d(v, attention_mask, conv_state_v)
        else:
            q = (self.q_proj(hidden_states))
            k = (self.k_proj(hidden_states))
            v = self.silu(self.v_proj(hidden_states))

        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask.unsqueeze(-1))

        q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v))

        if self.qk_activation != 'silu':
            if self.qk_activation == 'relu':
                q, k = q.relu(), k.relu()
            elif self.qk_activation == 'elu':
                q, k = elu_p1(q), elu_p1(k)
            elif self.qk_activation == 'identity':
                pass
            else:
                raise NotImplementedError

        if self.qk_norm is not None:
            if self.qk_norm == 'l2':
                k = torch.nn.functional.normalize(k, dim=-1, p=2).to(v)  # auto mixed precision type transfer is annoying.
                q = torch.nn.functional.normalize(q, dim=-1, p=2).to(v)
            elif self.qk_norm == 'sum':
                q = sum_norm(q).to(v)
                k = sum_norm(k).to(v)

        if self.use_beta:
            beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid()
        else:
            beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
        state = past_key_values[self.layer_idx][-1] if use_cache else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_linear_attn_delta_rule(q, k, v, beta, state, output_final_state=use_cache)
        elif mode == 'fused_chunk':
            assert self.chunk_size in [16, 32, 64]
            o, recurrent_state = fused_chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
        elif mode == 'chunk':
            assert self.chunk_size in [16, 32, 64]
            o, recurrent_state = chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            if self.use_short_conv:
                if self.share_conv_kernel:
                    state = (conv_state, recurrent_state)
                else:
                    state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
            else:
                state = (recurrent_state,)
            past_key_values.update(state, self.layer_idx)

        o = rearrange(o, 'b h l d -> b l h d')
        if self.use_gate:
            g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads)
            o = self.norm(o, g)
        else:
            o = self.norm(o)
        o = rearrange(o, 'b l h d -> b l (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values

    def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
        param = next(self.parameters())
        state = tuple()
        if self.use_short_conv:
            if self.share_conv_kernel:
                state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
            else:
                # for q/k/v each
                state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
                          param.new_zeros(batch_size, self.key_dim, self.conv_size),
                          param.new_zeros(batch_size, self.value_dim, self.conv_size))
        state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
        return state
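
The comment at the top of this file points to the delta rule of "Linear Transformers Are Secretly Fast Weight Programmers". For orientation, a naive per-token reference of that recurrence is sketched below; it is illustrative only, while the fused/chunked Triton kernels imported above are the actual implementation (and additionally handle chunking, caching, and mixed precision):

    import torch

    def delta_rule_reference(q, k, v, beta):
        # q, k: (B, H, L, Dk); v: (B, H, L, Dv); beta: (B, H, L), values in (0, 1)
        B, H, L, Dk = k.shape
        S = k.new_zeros(B, H, Dk, v.shape[-1])                    # fast-weight state
        outs = []
        for t in range(L):
            k_t, v_t, q_t = k[:, :, t], v[:, :, t], q[:, :, t]
            b_t = beta[:, :, t].unsqueeze(-1)
            v_old = torch.einsum('bhkd,bhk->bhd', S, k_t)         # what the state currently predicts for k_t
            S = S + torch.einsum('bhk,bhd->bhkd', k_t, b_t * (v_t - v_old))  # delta update
            outs.append(torch.einsum('bhkd,bhk->bhd', S, q_t))    # read out with the query
        return torch.stack(outs, dim=2)                           # (B, H, L, Dv)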
							
								
								
									
finetune/lora/v6/fla/layers/gated_abc.py (vendored, new file, 234 lines)
@@ -0,0 +1,234 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers.cache_utils import Cache

from fla.modules import (FusedRMSNormSwishGateLinear, RMSNormLinear,
                         RotaryEmbedding, ShortConvolution)
from fla.modules.activations import ACT2FN, swiglu_linear, swish
from fla.ops.abc.chunk_gate import chunk_gated_abc


class GatedABCAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 1024,
        expand_k: float = 1.,
        expand_v: float = 1.,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        share_conv_kernel: bool = True,
        num_slots: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_low_rank_dim: Optional[int] = None,
        gate_logit_normalizer: int = 16,
        feature_map: str = 'swish',
        use_rope: bool = False,
        use_output_gate: bool = False,
        use_norm: bool = True,
        layer_idx: Optional[int] = None,
        **kwargs
    ) -> GatedABCAttention:
        super().__init__()

        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.head_k_dim = self.key_dim // self.num_heads
        self.head_v_dim = self.value_dim // self.num_heads

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.share_conv_kernel = share_conv_kernel

        if gate_low_rank_dim is None:
            gate_low_rank_dim = self.hidden_size // 16
        self.gate_low_rank_dim = gate_low_rank_dim
        self.gate_logit_normalizer = gate_logit_normalizer

        self.feature_map = feature_map
        self.use_rope = use_rope
        self.use_output_gate = use_output_gate
        self.use_norm = use_norm

        if num_slots is None:
            num_slots = self.head_k_dim
        self.num_slots = num_slots

        self.layer_idx = layer_idx

        if layer_idx is None:
            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
        self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)

        if use_output_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            if share_conv_kernel:
                self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
            else:
                self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
                self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
                self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        if self.use_norm:
            if self.use_output_gate:
                self.g_norm = FusedRMSNormSwishGateLinear(self.hidden_size, elementwise_affine, norm_eps)
            else:
                self.g_norm = RMSNormLinear(self.hidden_size, elementwise_affine, norm_eps)
        self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

        if self.use_rope:
            self.rotary = RotaryEmbedding(self.head_k_dim)

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:

        last_state = past_key_values[self.layer_idx] if use_cache else None
        if self.use_short_conv:
            conv_state = last_state[0] if use_cache else None
            if self.share_conv_kernel:
                # conv state is updated inplace
                hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
                q = self.q_proj(hidden_states)
                k = self.k_proj(hidden_states)
                v = self.v_proj(hidden_states)
            else:
                conv_state_q = last_state[0] if use_cache else None
                conv_state_k = last_state[1] if use_cache else None
                conv_state_v = last_state[2] if use_cache else None
                q = self.q_proj(hidden_states)
                k = self.k_proj(hidden_states)
                v = self.v_proj(hidden_states)
                q = self.q_conv1d(q, attention_mask, conv_state_q)
                k = self.k_conv1d(k, attention_mask, conv_state_k)
                v = self.v_conv1d(v, attention_mask, conv_state_v)
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
        f = self.f_proj(hidden_states)

        if self.use_rope:
            q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
            k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
            seqlen_offset = 0
            if past_key_values is not None:
                seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            q, k = self.rotary(q, k, seqlen_offset)
            q = rearrange(q, 'b n h d -> b h n d', h=self.num_heads)
            k = rearrange(k, 'b n h d -> b h n d', h=self.num_kv_heads)
        else:
            q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
            if self.num_kv_groups > 1:
                k = repeat(k, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups)
            else:
                k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_kv_heads)
        if self.num_kv_groups > 1:
            v = repeat(v, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups)
            f = repeat(f, 'b n (h m) -> b (h g) n m', h=self.num_kv_heads, g=self.num_kv_groups)
        else:
            v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_kv_heads)
            f = rearrange(f, 'b n (h m) -> b h n m', h=self.num_kv_heads)

        if self.feature_map is not None:
            q, k, v = map(lambda x: ACT2FN[self.feature_map](x), (q, k, v))
        f = F.logsigmoid(f) / self.gate_logit_normalizer
        s = (1 - f.exp()).to(f.dtype)
        # dealing with left-padding
        if attention_mask is not None:
            s = s.mul_(attention_mask.view(attention_mask.shape[0], 1, -1, 1))
            v = v.mul_(attention_mask.view(attention_mask.shape[0], 1, -1, 1))

        recurrent_state = last_state[-2:] if use_cache else None
        o, recurrent_state = chunk_gated_abc(q, k, v, s, f,
                                             initial_state=recurrent_state,
                                             output_final_state=use_cache)
        if past_key_values is not None:
            if self.use_short_conv:
                if self.share_conv_kernel:
                    last_state = (conv_state,) + recurrent_state
                else:
                    last_state = (conv_state_q, conv_state_k, conv_state_v) + recurrent_state
            else:
                last_state = recurrent_state
            past_key_values.update(last_state, self.layer_idx, q.shape[2])

        o = rearrange(o, 'b h t d -> b t (h d)')
        if self.use_norm and not self.use_output_gate:
            o = swish(o)
            o = self.g_norm(o, self.o_proj.weight, self.o_proj.bias)
        elif self.use_output_gate and not self.use_norm:
            o = swiglu_linear(self.g_proj(hidden_states), o, self.o_proj.weight, self.o_proj.bias)
        elif self.use_output_gate and self.use_norm:
            o = self.g_norm(o, self.g_proj(hidden_states), self.o_proj.weight, self.o_proj.bias)
        else:
            o = self.o_proj(o)
        return o, None, past_key_values

    def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
        param = next(self.parameters())
        state = tuple()
        if self.use_short_conv:
            if self.share_conv_kernel:
                state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
            else:
                state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
                          param.new_zeros(batch_size, self.key_dim, self.conv_size),
                          param.new_zeros(batch_size, self.value_dim, self.conv_size))
        state += (param.new_zeros(batch_size, self.num_heads, self.head_k_dim, self.num_slots),
                  param.new_zeros(batch_size, self.num_heads, self.num_slots, self.head_v_dim))
        return state

    def state_size(self, sequence_length: int = 2048):
        return self.num_heads * self.key_dim * self.head_v_dim
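
In the forward pass above, the slot logits `f` are mapped to log-space forget factors with `logsigmoid` and `gate_logit_normalizer`, and `s = 1 - exp(f)` is the complementary write weight, so for every slot the retained fraction and the write weight sum to one. A small standalone illustration of that parameterization (the logits below are made up):

    import torch
    import torch.nn.functional as F

    gate_logit_normalizer = 16
    logits = torch.tensor([-4.0, 0.0, 4.0])
    f = F.logsigmoid(logits) / gate_logit_normalizer   # log forget factor, always <= 0
    forget = f.exp()                                   # in (0, 1), pushed towards 1 by the normalizer
    write = 1 - forget                                 # matches `s` in the forward pass
    print(forget, write, forget + write)               # forget + write is exactly ones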
							
								
								
									
finetune/lora/v6/fla/layers/gla.py (vendored, new file, 268 lines)
@@ -0,0 +1,268 @@
# -*- coding: utf-8 -*-


from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers.cache_utils import Cache

from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
from fla.modules.activations import ACT2FN
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla


class GatedLinearAttention(nn.Module):
    r"""
    The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635).  # noqa

    Args:
        mode (str, Optional):
            Which GLA kernel to use.
            Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 0.5.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 1.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        num_kv_heads (int, Optional):
            The number of key/value heads, used for MQA. Default: None.
        feature_map (str, Optional):
            Feature map function applied to queries/keys. Default: None.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `False`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        share_conv_kernel (bool, Optional):
            Whether to apply convolutions before the q/k/v mappings, only taking effect when `use_short_conv` is `True`. Default: `True`.
        use_output_gate (bool, Optional):
            Whether to use output gate. Default: `True`.
        gate_fn (str, Optional):
            The activation function for the output gate. Default: `swish`.
        elementwise_affine (bool, Optional):
            If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        gate_logit_normalizer (int, Optional):
            The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
        gate_low_rank_dim (int, Optional):
            The low rank dim for the gate projection. Default: 16.
        clamp_min (float, Optional):
            The minimum value for the gate logits. Default: None.
        fuse_norm (bool, Optional):
            Whether to fuse the norm and the output gate for a better memory footprint. Default: `True`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        share_conv_kernel: bool = True,
        use_output_gate: bool = True,
        gate_fn: str = 'swish',
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 16,
        gate_low_rank_dim: int = 16,
        clamp_min: Optional[float] = None,
        fuse_norm: bool = True,
        layer_idx: int = None,
    ) -> GatedLinearAttention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.share_conv_kernel = share_conv_kernel
        self.use_output_gate = use_output_gate

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.clamp_min = clamp_min
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
        if self.use_output_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            if share_conv_kernel:
                self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
            else:
                self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
                self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
                self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
                                     nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        if gate_fn == 'swish' and fuse_norm and use_output_gate:
            self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
            self.gate_fn = ACT2FN[gate_fn]

        self.gate_logit_normalizer = gate_logit_normalizer

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode

        last_state = past_key_values[self.layer_idx] if use_cache else None
        if self.use_short_conv:
            conv_state = last_state[0] if use_cache else None
            if self.share_conv_kernel:
                # conv state is updated inplace
                hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
                q = self.q_proj(hidden_states)
                k = self.k_proj(hidden_states)
                v = self.v_proj(hidden_states)
            else:
                conv_state_q = last_state[0] if use_cache else None
                conv_state_k = last_state[1] if use_cache else None
                conv_state_v = last_state[2] if use_cache else None
                q = self.q_proj(hidden_states)
                k = self.k_proj(hidden_states)
                v = self.v_proj(hidden_states)
                q = self.q_conv1d(q, attention_mask, conv_state_q)
                k = self.k_conv1d(k, attention_mask, conv_state_k)
                v = self.v_conv1d(v, attention_mask, conv_state_v)
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
        gk = self.gk_proj(hidden_states)

        if self.feature_map_fn is not None:
            q, k = map(self.feature_map_fn, (q, k))
        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask.unsqueeze(-1))
        q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads)
        if self.num_kv_groups > 1:
            k, v, gk = (repeat(x, 'b l (h d) -> b (h g) l d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v, gk))
        else:
            k, v, gk = (rearrange(x, 'b l (h d) -> b h l d', h=self.num_kv_heads) for x in (k, v, gk))
        gk = F.logsigmoid(gk) / self.gate_logit_normalizer

        if self.clamp_min is not None:
            gk = torch.clamp_min(gk, self.clamp_min)

        recurrent_state = last_state[-1] if use_cache else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
        elif mode == 'fused_chunk':
            o, recurrent_state = fused_chunk_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
        elif mode == 'chunk':
            o, recurrent_state = chunk_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            if self.use_short_conv:
                if self.share_conv_kernel:
                    last_state = (conv_state, recurrent_state)
                else:
                    last_state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
 | 
			
		||||
            else:
 | 
			
		||||
                last_state = (recurrent_state,)
 | 
			
		||||
            past_key_values.update(last_state, self.layer_idx, q.shape[2])
 | 
			
		||||
 | 
			
		||||
        o = rearrange(o, 'b h l d -> b l h d')
 | 
			
		||||
        if self.use_output_gate:
 | 
			
		||||
            g = self.g_proj(hidden_states)
 | 
			
		||||
            if self.fuse_norm_and_gate:
 | 
			
		||||
                g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
 | 
			
		||||
                o = self.g_norm_swish_gate(o, g)
 | 
			
		||||
                o = rearrange(o, 'b l h d -> b l (h d)')
 | 
			
		||||
            else:
 | 
			
		||||
                o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
 | 
			
		||||
                o = o * self.gate_fn(g)
 | 
			
		||||
        else:
 | 
			
		||||
            o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
 | 
			
		||||
        o = self.o_proj(o)
 | 
			
		||||
 | 
			
		||||
        return o, None, past_key_values
 | 
			
		||||
 | 
			
		||||
    def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
 | 
			
		||||
        param = next(self.parameters())
 | 
			
		||||
        state = tuple()
 | 
			
		||||
        if self.use_short_conv:
 | 
			
		||||
            if self.share_conv_kernel:
 | 
			
		||||
                state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
 | 
			
		||||
            else:
 | 
			
		||||
                state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
 | 
			
		||||
                          param.new_zeros(batch_size, self.key_dim, self.conv_size),
 | 
			
		||||
                          param.new_zeros(batch_size, self.value_dim, self.conv_size))
 | 
			
		||||
        state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
 | 
			
		||||
        return state
 | 
			
		||||
 | 
			
		||||
    def state_size(self, **kwargs) -> int:
 | 
			
		||||
        state_size = self.key_dim * self.head_v_dim
 | 
			
		||||
        for module in self.children():
 | 
			
		||||
            if isinstance(module, ShortConvolution):
 | 
			
		||||
                state_size += module.state_size
 | 
			
		||||
        return state_size
 | 
			
		||||
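A minimal sketch of the per-layer cache layout produced by `init_state` above (a hedged illustration: the surrounding class name falls outside this hunk, so `layer` below simply stands for an instance of it, assumed constructed with `use_short_conv=True` and `share_conv_kernel=False`):

    # Hypothetical usage; `layer` is an instance of the gated linear attention
    # module defined above (its class name is not visible in this hunk).
    state = layer.init_state(batch_size=2)
    # The tuple then holds (conv_state_q, conv_state_k, conv_state_v, recurrent_state):
    #   state[0].shape == (2, layer.key_dim,   layer.conv_size)
    #   state[1].shape == (2, layer.key_dim,   layer.conv_size)
    #   state[2].shape == (2, layer.value_dim, layer.conv_size)
    #   state[3].shape == (2, layer.num_heads, layer.head_qk_dim, layer.head_v_dim)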
							
								
								
									
										165
									
								
								finetune/lora/v6/fla/layers/hgrn.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										165
									
								
								finetune/lora/v6/fla/layers/hgrn.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,165 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
from transformers.cache_utils import Cache
 | 
			
		||||
 | 
			
		||||
from fla.modules import FusedRMSNormSwishGate, ShortConvolution
 | 
			
		||||
from fla.modules.activations import swiglu
 | 
			
		||||
from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRNAttention(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        mode: str = 'chunk',
 | 
			
		||||
        hidden_size: int = 1024,
 | 
			
		||||
        num_heads: Optional[int] = None,
 | 
			
		||||
        expand_ratio: Optional[int] = 1,
 | 
			
		||||
        use_short_conv: bool = False,
 | 
			
		||||
        conv_size: int = 4,
 | 
			
		||||
        conv_bias: bool = False,
 | 
			
		||||
        share_conv_kernel: bool = True,
 | 
			
		||||
        elementwise_affine: Optional[bool] = True,
 | 
			
		||||
        norm_eps: float = 1e-5,
 | 
			
		||||
        layer_idx: Optional[int] = None
 | 
			
		||||
    ) -> HGRNAttention:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.mode = mode
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
        self.expand_ratio = expand_ratio
 | 
			
		||||
        self.input_dim = int(hidden_size * expand_ratio)
 | 
			
		||||
        self.head_dim = self.input_dim // self.num_heads
 | 
			
		||||
 | 
			
		||||
        self.use_short_conv = use_short_conv
 | 
			
		||||
        self.conv_size = conv_size
 | 
			
		||||
        self.conv_bias = conv_bias
 | 
			
		||||
        self.share_conv_kernel = share_conv_kernel
 | 
			
		||||
 | 
			
		||||
        self.layer_idx = layer_idx
 | 
			
		||||
 | 
			
		||||
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
 | 
			
		||||
        assert self.hidden_size % num_heads == 0, f"hidden size must be divisible by num_heads of {num_heads}"
 | 
			
		||||
 | 
			
		||||
        self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
 | 
			
		||||
        self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
 | 
			
		||||
        self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
 | 
			
		||||
 | 
			
		||||
        if use_short_conv:
 | 
			
		||||
            self.conv_size = conv_size
 | 
			
		||||
            if share_conv_kernel:
 | 
			
		||||
                self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
 | 
			
		||||
            else:
 | 
			
		||||
                self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
 | 
			
		||||
                self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
 | 
			
		||||
                self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
 | 
			
		||||
 | 
			
		||||
        self.g_norm = FusedRMSNormSwishGate(self.input_dim, elementwise_affine, norm_eps)
 | 
			
		||||
        self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
 | 
			
		||||
 | 
			
		||||
        self.apply(self._initialize_weights)
 | 
			
		||||
 | 
			
		||||
    def _initialize_weights(self, module: nn.Module):
 | 
			
		||||
        if getattr(module, "_is_hf_initialized", False):
 | 
			
		||||
            return
 | 
			
		||||
        if isinstance(module, nn.Linear):
 | 
			
		||||
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                nn.init.zeros_(module.bias)
 | 
			
		||||
        module._is_hf_initialized = True
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[Cache] = None,
 | 
			
		||||
        use_cache: Optional[bool] = False,
 | 
			
		||||
        output_attentions: Optional[bool] = False,
 | 
			
		||||
        lower_bound: Optional[torch.Tensor] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
 | 
			
		||||
        # launching the triton kernel for just one token will actually be slower
 | 
			
		||||
        mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
 | 
			
		||||
 | 
			
		||||
        last_state = past_key_values[self.layer_idx] if use_cache else None
 | 
			
		||||
        if self.use_short_conv:
 | 
			
		||||
            conv_state = last_state[0] if use_cache else None
 | 
			
		||||
            if self.share_conv_kernel:
 | 
			
		||||
                # conv state is updated inplace
 | 
			
		||||
                hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
 | 
			
		||||
                i = self.i_proj(hidden_states)
 | 
			
		||||
                f = self.f_proj(hidden_states)
 | 
			
		||||
            else:
 | 
			
		||||
                conv_state_i = last_state[2] if use_cache else None
 | 
			
		||||
                conv_state_f = last_state[1] if use_cache else None
 | 
			
		||||
                i = self.i_conv1d(self.i_proj(hidden_states), attention_mask, conv_state_i)
 | 
			
		||||
                f = self.f_conv1d(self.f_proj(hidden_states), attention_mask, conv_state_f)
 | 
			
		||||
        else:
 | 
			
		||||
            i = self.i_proj(hidden_states)
 | 
			
		||||
            f = self.f_proj(hidden_states)
 | 
			
		||||
 | 
			
		||||
        # the lower bound for the first layer is zero
 | 
			
		||||
        if lower_bound is None or self.layer_idx == 0:
 | 
			
		||||
            i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
 | 
			
		||||
        else:
 | 
			
		||||
            g = lower_bound + (1 - lower_bound) * f.sigmoid()
 | 
			
		||||
            i, f = swiglu(i, 1 - g), g.log()
 | 
			
		||||
 | 
			
		||||
        # dealing with left-padding
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            i = i.mul_(attention_mask.unsqueeze(-1))
 | 
			
		||||
        i, f = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (i, f))
 | 
			
		||||
 | 
			
		||||
        recurrent_state = last_state[-1] if use_cache else None
 | 
			
		||||
        if mode == 'chunk':
 | 
			
		||||
            o, recurrent_state = chunk_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache)
 | 
			
		||||
        elif mode == 'fused_recurrent':
 | 
			
		||||
            o, recurrent_state = fused_recurrent_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache)
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError(f"Not supported mode `{mode}`.")
 | 
			
		||||
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            if self.use_short_conv:
 | 
			
		||||
                if self.share_conv_kernel:
 | 
			
		||||
                    last_state = (conv_state, recurrent_state)
 | 
			
		||||
                else:
 | 
			
		||||
                    last_state = (conv_state_i, conv_state_f, recurrent_state)
 | 
			
		||||
            else:
 | 
			
		||||
                last_state = (recurrent_state,)
 | 
			
		||||
            past_key_values.update(last_state, self.layer_idx, i.shape[2])
 | 
			
		||||
 | 
			
		||||
        o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)'))
 | 
			
		||||
        o = self.o_proj(o)
 | 
			
		||||
 | 
			
		||||
        return o, None, past_key_values
 | 
			
		||||
 | 
			
		||||
    def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
 | 
			
		||||
        param = next(self.parameters())
 | 
			
		||||
        state = tuple()
 | 
			
		||||
        if self.use_short_conv:
 | 
			
		||||
            if self.share_conv_kernel:
 | 
			
		||||
                state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
 | 
			
		||||
            else:
 | 
			
		||||
                state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),
 | 
			
		||||
                          param.new_zeros(batch_size, self.hidden_size, self.conv_size),
 | 
			
		||||
                          param.new_zeros(batch_size, self.hidden_size, self.conv_size))
 | 
			
		||||
        state += (param.new_zeros(batch_size, self.num_heads, self.head_dim),)
 | 
			
		||||
        return state
 | 
			
		||||
 | 
			
		||||
    def state_size(self, **kwargs) -> int:
 | 
			
		||||
        state_size = self.hidden_size
 | 
			
		||||
        for module in self.children():
 | 
			
		||||
            if isinstance(module, ShortConvolution):
 | 
			
		||||
                state_size += module.state_size
 | 
			
		||||
        return state_size
 | 
			
		||||
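For orientation, a minimal usage sketch of the `HGRNAttention` layer defined above (assumptions: the package is importable as `fla.layers.hgrn` as laid out in this commit, and a CUDA device with the Triton kernels backing `fla.ops.hgrn` is available):

    import torch
    from fla.layers.hgrn import HGRNAttention  # import path assumed from this commit's layout

    layer = HGRNAttention(mode='chunk', hidden_size=1024, num_heads=8).to(torch.bfloat16).cuda()
    x = torch.randn(2, 512, 1024, dtype=torch.bfloat16, device='cuda')
    o, _, cache = layer(x)       # forward returns (output, None, past_key_values)
    assert o.shape == x.shape    # o_proj maps the gated output back to hidden_size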
							
								
								
									
										186
									
								
								finetune/lora/v6/fla/layers/hgrn2.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								finetune/lora/v6/fla/layers/hgrn2.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,186 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904]
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
from transformers.cache_utils import Cache
 | 
			
		||||
 | 
			
		||||
from fla.modules import RMSNorm, ShortConvolution
 | 
			
		||||
from fla.modules.activations import swish
 | 
			
		||||
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRN2Attention(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        mode: str = 'chunk',
 | 
			
		||||
        hidden_size: int = 1024,
 | 
			
		||||
        num_heads: Optional[int] = None,
 | 
			
		||||
        expand_ratio: Optional[int] = 128,
 | 
			
		||||
        use_short_conv: bool = False,
 | 
			
		||||
        conv_size: int = 4,
 | 
			
		||||
        conv_bias: bool = False,
 | 
			
		||||
        share_conv_kernel: bool = True,
 | 
			
		||||
        elementwise_affine: Optional[bool] = True,
 | 
			
		||||
        norm_eps: float = 1e-5,
 | 
			
		||||
        layer_idx: Optional[int] = None
 | 
			
		||||
    ) -> HGRN2Attention:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.mode = mode
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
 | 
			
		||||
        if expand_ratio is None and num_heads is not None:
 | 
			
		||||
            expand_ratio = hidden_size // num_heads
 | 
			
		||||
        elif expand_ratio is not None and num_heads is None:
 | 
			
		||||
            num_heads = hidden_size // expand_ratio
 | 
			
		||||
        else:
 | 
			
		||||
            raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
        self.expand_ratio = expand_ratio
 | 
			
		||||
 | 
			
		||||
        self.use_short_conv = use_short_conv
 | 
			
		||||
        self.conv_size = conv_size
 | 
			
		||||
        self.conv_bias = conv_bias
 | 
			
		||||
        self.share_conv_kernel = share_conv_kernel
 | 
			
		||||
 | 
			
		||||
        self.forget_dim = int(self.num_heads * self.expand_ratio)
 | 
			
		||||
        self.input_dim = hidden_size
 | 
			
		||||
        self.layer_idx = layer_idx
 | 
			
		||||
 | 
			
		||||
        assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
 | 
			
		||||
        assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}"
 | 
			
		||||
        assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}"
 | 
			
		||||
 | 
			
		||||
        self.head_f_dim = self.expand_ratio
 | 
			
		||||
        self.head_i_dim = self.hidden_size // num_heads
 | 
			
		||||
 | 
			
		||||
        self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
 | 
			
		||||
        self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
 | 
			
		||||
        self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
 | 
			
		||||
 | 
			
		||||
        if use_short_conv:
 | 
			
		||||
            self.conv_size = conv_size
 | 
			
		||||
            if share_conv_kernel:
 | 
			
		||||
                self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
 | 
			
		||||
            else:
 | 
			
		||||
                self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu')
 | 
			
		||||
                self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu')
 | 
			
		||||
                self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
 | 
			
		||||
 | 
			
		||||
        self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, norm_eps)
 | 
			
		||||
        self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
 | 
			
		||||
 | 
			
		||||
        self.apply(self._initialize_weights)
 | 
			
		||||
 | 
			
		||||
    def _initialize_weights(self, module: nn.Module):
 | 
			
		||||
        if getattr(module, "_is_hf_initialized", False):
 | 
			
		||||
            return
 | 
			
		||||
        if isinstance(module, nn.Linear):
 | 
			
		||||
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                nn.init.zeros_(module.bias)
 | 
			
		||||
        module._is_hf_initialized = True
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[Cache] = None,
 | 
			
		||||
        use_cache: Optional[bool] = False,
 | 
			
		||||
        output_attentions: Optional[bool] = False,
 | 
			
		||||
        lower_bound: Optional[torch.Tensor] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
 | 
			
		||||
        # launching the triton kernel for just one token will actually be slower
 | 
			
		||||
        mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
 | 
			
		||||
 | 
			
		||||
        last_state = past_key_values[self.layer_idx] if use_cache else None
 | 
			
		||||
        if self.use_short_conv:
 | 
			
		||||
            conv_state = last_state[0] if use_cache else None
 | 
			
		||||
            if self.share_conv_kernel:
 | 
			
		||||
                # conv state is updated inplace
 | 
			
		||||
                hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
 | 
			
		||||
                q = self.q_proj(hidden_states)
 | 
			
		||||
                f = self.f_proj(hidden_states)
 | 
			
		||||
                i = self.i_proj(hidden_states)
 | 
			
		||||
            else:
 | 
			
		||||
                conv_state_q = last_state[0] if use_cache else None
 | 
			
		||||
                conv_state_f = last_state[1] if use_cache else None
 | 
			
		||||
                conv_state_i = last_state[2] if use_cache else None
 | 
			
		||||
                q = self.q_proj(hidden_states)
 | 
			
		||||
                f = self.f_proj(hidden_states)
 | 
			
		||||
                i = self.i_proj(hidden_states)
 | 
			
		||||
                q = self.q_conv1d(q, attention_mask, conv_state_q)
 | 
			
		||||
                f = self.f_conv1d(f, attention_mask, conv_state_f)
 | 
			
		||||
                i = self.i_conv1d(i, attention_mask, conv_state_i)
 | 
			
		||||
        else:
 | 
			
		||||
            q = self.q_proj(hidden_states)
 | 
			
		||||
            f = self.f_proj(hidden_states)
 | 
			
		||||
            i = self.i_proj(hidden_states)
 | 
			
		||||
 | 
			
		||||
        # dealing with left-padding
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            i = i.mul_(attention_mask.unsqueeze(-1))
 | 
			
		||||
 | 
			
		||||
        q = swish(q)
 | 
			
		||||
        # the lower bound for the first layer is zero
 | 
			
		||||
        if lower_bound is None or self.layer_idx == 0:
 | 
			
		||||
            k, g = 1 - f.sigmoid(), F.logsigmoid(f)
 | 
			
		||||
        else:
 | 
			
		||||
            g = lower_bound + (1 - lower_bound) * f.sigmoid()
 | 
			
		||||
            k, g = 1 - g, g.log()
 | 
			
		||||
        q, k, i, g = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, i, g))
 | 
			
		||||
 | 
			
		||||
        recurrent_state = last_state[-1] if use_cache else None
 | 
			
		||||
        if mode == 'fused_recurrent':
 | 
			
		||||
            o, recurrent_state = fused_recurrent_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
 | 
			
		||||
        elif mode == 'fused_chunk':
 | 
			
		||||
            o, recurrent_state = fused_chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
 | 
			
		||||
        elif mode == 'chunk':
 | 
			
		||||
            o, recurrent_state = chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError(f"Not supported mode `{mode}`.")
 | 
			
		||||
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            if self.use_short_conv:
 | 
			
		||||
                if self.share_conv_kernel:
 | 
			
		||||
                    last_state = (conv_state, recurrent_state)
 | 
			
		||||
                else:
 | 
			
		||||
                    last_state = (conv_state_q, conv_state_f, conv_state_i, recurrent_state)
 | 
			
		||||
            else:
 | 
			
		||||
                last_state = (recurrent_state,)
 | 
			
		||||
            past_key_values.update(last_state, self.layer_idx, q.shape[2])
 | 
			
		||||
 | 
			
		||||
        o = self.g_norm(rearrange(o, 'b h l d -> b l (h d)'))
 | 
			
		||||
        o = self.o_proj(o)
 | 
			
		||||
 | 
			
		||||
        return o, None, past_key_values
 | 
			
		||||
 | 
			
		||||
    def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
 | 
			
		||||
        param = next(self.parameters())
 | 
			
		||||
        state = tuple()
 | 
			
		||||
        if self.use_short_conv:
 | 
			
		||||
            if self.share_conv_kernel:
 | 
			
		||||
                state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
 | 
			
		||||
            else:
 | 
			
		||||
                state += (param.new_zeros(batch_size, self.forget_dim, self.conv_size),
 | 
			
		||||
                          param.new_zeros(batch_size, self.forget_dim, self.conv_size),
 | 
			
		||||
                          param.new_zeros(batch_size, self.input_dim, self.conv_size))
 | 
			
		||||
        state += (param.new_zeros(batch_size, self.num_heads, self.head_f_dim, self.head_i_dim),)
 | 
			
		||||
        return state
 | 
			
		||||
 | 
			
		||||
    def state_size(self, **kwargs) -> int:
 | 
			
		||||
        state_size = self.forget_dim * self.head_i_dim
 | 
			
		||||
        for module in self.children():
 | 
			
		||||
            if isinstance(module, ShortConvolution):
 | 
			
		||||
                state_size += module.state_size
 | 
			
		||||
        return state_size
 | 
			
		||||
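Analogously, a minimal usage sketch for `HGRN2Attention` (same assumptions as above; with the default `expand_ratio=128` and `hidden_size=1024`, the layer derives `num_heads=8` on its own):

    import torch
    from fla.layers.hgrn2 import HGRN2Attention  # import path assumed from this commit's layout

    layer = HGRN2Attention(hidden_size=1024).to(torch.bfloat16).cuda()  # num_heads inferred from expand_ratio
    x = torch.randn(2, 512, 1024, dtype=torch.bfloat16, device='cuda')
    o, _, _ = layer(x)
    assert o.shape == x.shape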
							
								
								
									
										156
									
								
								finetune/lora/v6/fla/layers/linear_attn.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										156
									
								
								finetune/lora/v6/fla/layers/linear_attn.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,156 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
 | 
			
		||||
from fla.modules import RMSNorm
 | 
			
		||||
from fla.modules.feature_map import (DPFPFeatureMap, HadamardFeatureMap,
 | 
			
		||||
                                     HedgehogFeatureMap, T2RFeatureMap)
 | 
			
		||||
from fla.ops.linear_attn import (chunk_linear_attn, fused_chunk_linear_attn,
 | 
			
		||||
                                 fused_recurrent_linear_attn)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LinearAttention(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int = 1024,
 | 
			
		||||
        expand_k: float = 1.0,
 | 
			
		||||
        expand_v: float = 1.0,
 | 
			
		||||
        num_heads: int = 8,
 | 
			
		||||
        mode: str = 'chunk',
 | 
			
		||||
        feature_map: str = 'elementwise_product',
 | 
			
		||||
        tie_feature_map_qk: bool = False,
 | 
			
		||||
        output_norm: str = 'rmsnorm',
 | 
			
		||||
        norm_q: bool = False,
 | 
			
		||||
        norm_k: bool = False,
 | 
			
		||||
        # standard linear attention normalization
 | 
			
		||||
        do_feature_map_norm: bool = False,
 | 
			
		||||
        elementwise_affine: bool = True,
 | 
			
		||||
        norm_eps: float = 1e-5,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        assert feature_map in ['elu', 'relu', 'hedgehog', 't2r', 'dpfp',
 | 
			
		||||
                               'identity', 'elementwise_product'], f"Not supported feature map `{feature_map}`."
 | 
			
		||||
 | 
			
		||||
        assert output_norm in ['rmsnorm', 'identity'], f"Not supported output norm `{output_norm}`."
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.mode = mode
 | 
			
		||||
        self.key_dim = int(hidden_size * expand_k)
 | 
			
		||||
        self.value_dim = int(hidden_size * expand_v)
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
 | 
			
		||||
        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
 | 
			
		||||
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
 | 
			
		||||
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
 | 
			
		||||
 | 
			
		||||
        self.head_qk_dim = self.key_dim // num_heads
 | 
			
		||||
        self.head_v_dim = self.value_dim // num_heads
 | 
			
		||||
 | 
			
		||||
        if feature_map == 'hedgehog':
 | 
			
		||||
            if tie_feature_map_qk:
 | 
			
		||||
                self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
 | 
			
		||||
            else:
 | 
			
		||||
                self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_qk_dim)
 | 
			
		||||
                self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
 | 
			
		||||
 | 
			
		||||
        elif feature_map == 't2r':
 | 
			
		||||
            if tie_feature_map_qk:
 | 
			
		||||
                self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
 | 
			
		||||
            else:
 | 
			
		||||
                self.feature_map_q = T2RFeatureMap(head_dim=self.head_qk_dim)
 | 
			
		||||
                self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
 | 
			
		||||
 | 
			
		||||
        elif feature_map == 'elementwise_product':
 | 
			
		||||
            if tie_feature_map_qk:
 | 
			
		||||
                self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
 | 
			
		||||
            else:
 | 
			
		||||
                self.feature_map_q = HadamardFeatureMap(head_dim=self.head_qk_dim)
 | 
			
		||||
                self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
 | 
			
		||||
 | 
			
		||||
        elif feature_map == 'dpfp':
 | 
			
		||||
            self.feature_map_q = DPFPFeatureMap(head_dim=self.head_qk_dim)
 | 
			
		||||
            self.feature_map_k = DPFPFeatureMap(head_dim=self.head_qk_dim)
 | 
			
		||||
 | 
			
		||||
        elif feature_map == 'elu':
 | 
			
		||||
            def elu(x):
 | 
			
		||||
                return F.elu(x) + 1
 | 
			
		||||
            self.feature_map_q = elu
 | 
			
		||||
            self.feature_map_k = elu
 | 
			
		||||
 | 
			
		||||
        elif feature_map == 'relu':
 | 
			
		||||
            self.feature_map_q = nn.ReLU()
 | 
			
		||||
            self.feature_map_k = nn.ReLU()
 | 
			
		||||
 | 
			
		||||
        elif feature_map == 'identity':
 | 
			
		||||
            self.feature_map_q = nn.Identity()
 | 
			
		||||
            self.feature_map_k = nn.Identity()
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
        self.do_feature_map_norm = do_feature_map_norm
 | 
			
		||||
        if output_norm == 'rmsnorm':
 | 
			
		||||
            self.norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
 | 
			
		||||
        elif output_norm == 'identity':
 | 
			
		||||
            self.norm = nn.Identity()
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
 | 
			
		||||
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
 | 
			
		||||
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
 | 
			
		||||
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
 | 
			
		||||
 | 
			
		||||
        self.norm_q = norm_q
 | 
			
		||||
        self.norm_k = norm_k
 | 
			
		||||
 | 
			
		||||
        self.apply(self._initialize_weights)
 | 
			
		||||
 | 
			
		||||
    def _initialize_weights(self, module: nn.Module):
 | 
			
		||||
        if getattr(module, "_is_hf_initialized", False):
 | 
			
		||||
            return
 | 
			
		||||
        if isinstance(module, nn.Linear):
 | 
			
		||||
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                nn.init.zeros_(module.bias)
 | 
			
		||||
        module._is_hf_initialized = True
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        mode = self.mode
 | 
			
		||||
        q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
 | 
			
		||||
        k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
 | 
			
		||||
        v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
 | 
			
		||||
        q = self.feature_map_q(q)
 | 
			
		||||
        k = self.feature_map_k(k)
 | 
			
		||||
        if self.norm_q:
 | 
			
		||||
            q = q / (q.sum(-1, keepdim=True) + 1e-4)
 | 
			
		||||
        if self.norm_k:
 | 
			
		||||
            k = k / (k.sum(-1, keepdim=True) + 1e-4)
 | 
			
		||||
 | 
			
		||||
        if mode == 'chunk':
 | 
			
		||||
            o = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
 | 
			
		||||
        elif mode == 'fused_chunk':
 | 
			
		||||
            o = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
 | 
			
		||||
        elif mode == 'fused_recurrent':
 | 
			
		||||
            o = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError
 | 
			
		||||
        o = self.norm(o)
 | 
			
		||||
        o = rearrange(o, 'b h n d -> b n (h d)')
 | 
			
		||||
        o = self.o_proj(o)
 | 
			
		||||
        return o
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    import torch
 | 
			
		||||
    batch = 4
 | 
			
		||||
    seq_len = 1024
 | 
			
		||||
    hidden_size = 1024
 | 
			
		||||
    x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True)
 | 
			
		||||
    model = LinearAttention(hidden_size, feature_map='dpfp').to(torch.bfloat16).cuda()
 | 
			
		||||
    y = model(x)
 | 
			
		||||
    print(y.shape)
 | 
			
		||||
    y.sum().backward()
 | 
			
		||||
    print(x.grad.shape)
 | 
			
		||||
							
								
								
									
										271
									
								
								finetune/lora/v6/fla/layers/multiscale_retention.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										271
									
								
								finetune/lora/v6/fla/layers/multiscale_retention.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,271 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
from einops import rearrange, repeat
 | 
			
		||||
from transformers.activations import ACT2FN
 | 
			
		||||
from transformers.cache_utils import Cache
 | 
			
		||||
 | 
			
		||||
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
 | 
			
		||||
from fla.modules.rotary import RotaryEmbedding
 | 
			
		||||
from fla.ops.retention import (chunk_retention, fused_chunk_retention,
 | 
			
		||||
                               fused_recurrent_retention, parallel_retention)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MultiScaleRetention(nn.Module):
 | 
			
		||||
    r"""
 | 
			
		||||
    The layer implementation for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf).  # noqa
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        mode (str, Optional):
 | 
			
		||||
            Which Retention kernel to use.
 | 
			
		||||
            Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`.
 | 
			
		||||
            Default: `fused_chunk`.
 | 
			
		||||
        hidden_size (int, Optional):
 | 
			
		||||
            The hidden size of the input. Default: 1024.
 | 
			
		||||
        expand_k (float, Optional):
 | 
			
		||||
            The expansion ratio for the key dim. Default: 1.0.
 | 
			
		||||
        expand_v (float, Optional):
 | 
			
		||||
            The expansion ratio for the value dim. Default: 2.0.
 | 
			
		||||
        num_heads (int, Optional):
 | 
			
		||||
            The number of heads. Default: 8.
 | 
			
		||||
        num_kv_heads (int, Optional):
 | 
			
		||||
            The number of key/value heads, used for MQA. Default: None.
 | 
			
		||||
        feature_map (str, Optional):
 | 
			
		||||
            Feature map function applied to queries/keys. Default: None.
 | 
			
		||||
        use_short_conv (bool, Optional):
 | 
			
		||||
            Whether to use short convolutions. Default: `False`.
 | 
			
		||||
        conv_size (int, Optional):
 | 
			
		||||
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
 | 
			
		||||
        conv_bias (bool, Optional):
 | 
			
		||||
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
 | 
			
		||||
        share_conv_kernel (bool, Optional):
 | 
			
		||||
            Whether to apply convolutions before the q/k/v mappings, only taking effect when `use_short_conv` is `True`. Default: `True`.
 | 
			
		||||
        use_output_gate (bool, Optional):
 | 
			
		||||
            Whether to use output gate. Default: `True`.
 | 
			
		||||
        gate_fn (str, Optional):
 | 
			
		||||
            The activation function for the output gate. Default: `swish`.
 | 
			
		||||
        elementwise_affine (bool, Optional):
 | 
			
		||||
            If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
 | 
			
		||||
        norm_eps (float, Optional):
 | 
			
		||||
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
 | 
			
		||||
        fuse_norm (bool, Optional):
 | 
			
		||||
            Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
 | 
			
		||||
        layer_idx (int, Optional):
 | 
			
		||||
            The index of the layer. Default: None.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        mode: str = 'fused_chunk',
 | 
			
		||||
        hidden_size: int = 1024,
 | 
			
		||||
        expand_k: float = 1.0,
 | 
			
		||||
        expand_v: float = 2.0,
 | 
			
		||||
        num_heads: int = 8,
 | 
			
		||||
        num_kv_heads: Optional[int] = None,
 | 
			
		||||
        feature_map: Optional[str] = None,
 | 
			
		||||
        use_short_conv: bool = False,
 | 
			
		||||
        conv_size: int = 4,
 | 
			
		||||
        conv_bias: bool = False,
 | 
			
		||||
        share_conv_kernel: bool = True,
 | 
			
		||||
        use_output_gate: bool = True,
 | 
			
		||||
        gate_fn: str = 'swish',
 | 
			
		||||
        elementwise_affine: Optional[bool] = True,
 | 
			
		||||
        norm_eps: float = 1e-5,
 | 
			
		||||
        fuse_norm: bool = True,
 | 
			
		||||
        layer_idx: Optional[int] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ) -> MultiScaleRetention:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.mode = mode
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.expand_k = expand_k
 | 
			
		||||
        self.expand_v = expand_v
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
 | 
			
		||||
        self.num_kv_groups = self.num_heads // self.num_kv_heads
 | 
			
		||||
        self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
 | 
			
		||||
 | 
			
		||||
        self.use_short_conv = use_short_conv
 | 
			
		||||
        self.conv_size = conv_size
 | 
			
		||||
        self.conv_bias = conv_bias
 | 
			
		||||
        self.share_conv_kernel = share_conv_kernel
 | 
			
		||||
        self.use_output_gate = use_output_gate
 | 
			
		||||
 | 
			
		||||
        self.key_dim = int(hidden_size * expand_k)
 | 
			
		||||
        self.value_dim = int(hidden_size * expand_v)
 | 
			
		||||
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
 | 
			
		||||
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
 | 
			
		||||
        self.layer_idx = layer_idx
 | 
			
		||||
 | 
			
		||||
        assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not supported mode `{mode}`."
 | 
			
		||||
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
 | 
			
		||||
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
 | 
			
		||||
 | 
			
		||||
        self.head_qk_dim = self.key_dim // num_heads
 | 
			
		||||
        self.head_v_dim = self.value_dim // num_heads
 | 
			
		||||
 | 
			
		||||
        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
 | 
			
		||||
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
 | 
			
		||||
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
 | 
			
		||||
        if self.use_output_gate:
 | 
			
		||||
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
 | 
			
		||||
 | 
			
		||||
        if use_short_conv:
 | 
			
		||||
            self.conv_size = conv_size
 | 
			
		||||
            if share_conv_kernel:
 | 
			
		||||
                self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
 | 
			
		||||
            else:
 | 
			
		||||
                self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
 | 
			
		||||
                self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
 | 
			
		||||
                self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
 | 
			
		||||
 | 
			
		||||
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
 | 
			
		||||
 | 
			
		||||
        if gate_fn == 'swish' and fuse_norm and use_output_gate:
 | 
			
		||||
            self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
 | 
			
		||||
            self.fuse_norm_and_gate = True
 | 
			
		||||
        else:
 | 
			
		||||
            self.fuse_norm_and_gate = False
 | 
			
		||||
            self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
 | 
			
		||||
            self.gate_fn = ACT2FN[gate_fn]
 | 
			
		||||
 | 
			
		||||
        # TODO: fix this issue
 | 
			
		||||
        # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180
 | 
			
		||||
        # Ideally, we would want to support arbitrary d_head_qk
 | 
			
		||||
        assert self.head_qk_dim <= 256, "head_qk_dim must be less than or equal to 256"
 | 
			
		||||
        self.rotary = RotaryEmbedding(dim=self.head_qk_dim)
 | 
			
		||||
 | 
			
		||||
        self.apply(self._initialize_weights)
 | 
			
		||||
 | 
			
		||||
    def _initialize_weights(self, module: nn.Module):
 | 
			
		||||
        if getattr(module, "_is_hf_initialized", False):
 | 
			
		||||
            return
 | 
			
		||||
        if isinstance(module, nn.Linear):
 | 
			
		||||
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                nn.init.zeros_(module.bias)
 | 
			
		||||
        module._is_hf_initialized = True
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[Cache] = None,
 | 
			
		||||
        use_cache: Optional[bool] = False,
 | 
			
		||||
        output_attentions: Optional[bool] = False,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
 | 
			
		||||
        # launching the triton kernel for just one token will actually be slower
 | 
			
		||||
        mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
 | 
			
		||||
 | 
			
		||||
        last_state = past_key_values[self.layer_idx] if use_cache else None
 | 
			
		||||
        if self.use_short_conv:
 | 
			
		||||
            conv_state = last_state[0] if use_cache else None
 | 
			
		||||
            if self.share_conv_kernel:
 | 
			
		||||
                # conv state is updated inplace
 | 
			
		||||
                hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
 | 
			
		||||
                q = self.q_proj(hidden_states)
 | 
			
		||||
                k = self.k_proj(hidden_states)
 | 
			
		||||
                v = self.v_proj(hidden_states)
 | 
			
		||||
            else:
 | 
			
		||||
                conv_state_q = last_state[0] if use_cache else None
 | 
			
		||||
                conv_state_k = last_state[1] if use_cache else None
 | 
			
		||||
                conv_state_v = last_state[2] if use_cache else None
 | 
			
		||||
                q = self.q_proj(hidden_states)
 | 
			
		||||
                k = self.k_proj(hidden_states)
 | 
			
		||||
                v = self.v_proj(hidden_states)
 | 
			
		||||
                q = self.q_conv1d(q, attention_mask, conv_state_q)
 | 
			
		||||
                k = self.k_conv1d(k, attention_mask, conv_state_k)
 | 
			
		||||
                v = self.v_conv1d(v, attention_mask, conv_state_v)
 | 
			
		||||
        else:
 | 
			
		||||
            q = self.q_proj(hidden_states)
 | 
			
		||||
            k = self.k_proj(hidden_states)
 | 
			
		||||
            v = self.v_proj(hidden_states)
 | 
			
		||||
 | 
			
		||||
        # dealing with left-padding
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            v = v.mul_(attention_mask.unsqueeze(-1))
 | 
			
		||||
        q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
 | 
			
		||||
        k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
 | 
			
		||||
        if self.feature_map_fn is not None:
 | 
			
		||||
            q, k = map(self.feature_map_fn, (q, k))
 | 
			
		||||
 | 
			
		||||
        seqlen_offset, max_seqlen = 0, None
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
 | 
			
		||||
            max_seqlen = q.shape[1] + seqlen_offset
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            # to account for the offsets introduced by left-padding tokens
 | 
			
		||||
            seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
 | 
			
		||||
            max_seqlen = q.shape[1] + max(seqlen_offset)
 | 
			
		||||
        q, k = self.rotary(q, k, seqlen_offset, max_seqlen)
 | 
			
		||||
        q = q.transpose(1, 2)
 | 
			
		||||
        if self.num_kv_groups > 1:
 | 
			
		||||
            k = repeat(k, 'b t h d -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups)
 | 
			
		||||
            v = repeat(v, 'b t (h d) -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups)
 | 
			
		||||
        else:
 | 
			
		||||
            k, v = rearrange(k, 'b t h d -> b h t d'), rearrange(v, 'b t (h d) -> b h t d', h=self.num_kv_heads)
 | 
			
		||||
 | 
			
		||||
        state = last_state[-1] if use_cache else None
 | 
			
		||||
        if mode == 'chunk':
 | 
			
		||||
            o, recurrent_state = chunk_retention(q, k, v, initial_state=state, output_final_state=use_cache)
 | 
			
		||||
        elif mode == 'fused_chunk':
 | 
			
		||||
            o, recurrent_state = fused_chunk_retention(q, k, v, initial_state=state, output_final_state=use_cache)
 | 
			
		||||
        elif mode == 'parallel':
 | 
			
		||||
            o, recurrent_state = parallel_retention(q, k, v, initial_state=state, output_final_state=use_cache)
 | 
			
		||||
        elif mode == 'fused_recurrent':
 | 
			
		||||
            o, recurrent_state = fused_recurrent_retention(q, k, v, initial_state=state, output_final_state=use_cache)
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError(f"Not supported mode `{mode}`.")
 | 
			
		||||
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            if self.use_short_conv:
 | 
			
		||||
                if self.share_conv_kernel:
 | 
			
		||||
                    last_state = (conv_state, recurrent_state)
 | 
			
		||||
                else:
 | 
			
		||||
                    last_state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
 | 
			
		||||
            else:
 | 
			
		||||
                last_state = (recurrent_state,)
 | 
			
		||||
            past_key_values.update(last_state, self.layer_idx, q.shape[2])
 | 
			
		||||
 | 
			
		||||
        o = rearrange(o, 'b h l d -> b l h d')
 | 
			
		||||
        if self.use_output_gate:
 | 
			
		||||
            g = self.g_proj(hidden_states)
 | 
			
		||||
            if self.fuse_norm_and_gate:
 | 
			
		||||
                g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
 | 
			
		||||
                o = self.g_norm_swish_gate(o, g)
 | 
			
		||||
                o = rearrange(o, 'b l h d -> b l (h d)')
 | 
			
		||||
            else:
 | 
			
		||||
                o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
 | 
			
		||||
                o = o * self.gate_fn(g)
 | 
			
		||||
        else:
 | 
			
		||||
            o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
 | 
			
		||||
        o = self.o_proj(o)
 | 
			
		||||
 | 
			
		||||
        return o, None, past_key_values
 | 
			
		||||
 | 
			
		||||
    def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
 | 
			
		||||
        param = next(self.parameters())
 | 
			
		||||
        state = tuple()
 | 
			
		||||
        if self.use_short_conv:
 | 
			
		||||
            if self.share_conv_kernel:
 | 
			
		||||
                state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
 | 
			
		||||
            else:
 | 
			
		||||
                state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
 | 
			
		||||
                          param.new_zeros(batch_size, self.key_dim, self.conv_size),
 | 
			
		||||
                          param.new_zeros(batch_size, self.value_dim, self.conv_size))
 | 
			
		||||
        state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
 | 
			
		||||
        return state
 | 
			
		||||
 | 
			
		||||
    def state_size(self, **kwargs) -> int:
 | 
			
		||||
        state_size = self.key_dim * self.head_v_dim
 | 
			
		||||
        for module in self.children():
 | 
			
		||||
            if isinstance(module, ShortConvolution):
 | 
			
		||||
                state_size += module.state_size
 | 
			
		||||
        return state_size
 | 
			
		||||
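A minimal usage sketch for `MultiScaleRetention` as well (same assumptions: CUDA device, Triton kernels backing `fla.ops.retention` available, import path as laid out in this commit):

    import torch
    from fla.layers.multiscale_retention import MultiScaleRetention  # import path assumed from this commit's layout

    layer = MultiScaleRetention(hidden_size=1024, num_heads=8).to(torch.bfloat16).cuda()
    x = torch.randn(2, 512, 1024, dtype=torch.bfloat16, device='cuda')
    o, _, _ = layer(x)           # (output, attentions placeholder, cache)
    assert o.shape == x.shape    # expand_v widens values internally; o_proj maps back to hidden_size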
							
								
								
									
										137
									
								
								finetune/lora/v6/fla/layers/rebased.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								finetune/lora/v6/fla/layers/rebased.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,137 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
 | 
			
		||||
from fla.modules.feature_map import RebasedFeatureMap
 | 
			
		||||
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
 | 
			
		||||
from fla.ops.rebased import parallel_rebased
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReBasedLinearAttention(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        l_max: int = 2048,
 | 
			
		||||
        feature_dim: int = 16,
 | 
			
		||||
        num_key_value_heads: int = 16,
 | 
			
		||||
        num_heads: int = 16,
 | 
			
		||||
        use_gamma: Optional[bool] = True,
 | 
			
		||||
        use_beta: Optional[bool] = True,
 | 
			
		||||
        normalize: Optional[bool] = True,
 | 
			
		||||
        causal: bool = True,
 | 
			
		||||
        eps: float = 1e-5,
 | 
			
		||||
        mode: str = "parallel",
 | 
			
		||||
        layer_idx: Optional[int] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ) -> ReBasedLinearAttention:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.l_max = l_max
 | 
			
		||||
        self.mode = mode
 | 
			
		||||
        assert self.mode in ['fused_chunk', 'parallel', 'chunk'], f"Not supported mode `{self.mode}`."
 | 
			
		||||
 | 
			
		||||
        # linear attention
 | 
			
		||||
        self.feature_dim = feature_dim
 | 
			
		||||
        self.num_key_value_heads = num_key_value_heads
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
        self.head_dim = self.hidden_size // self.num_key_value_heads
 | 
			
		||||
        self.use_gamma = use_gamma
 | 
			
		||||
        self.use_beta = use_beta
 | 
			
		||||
        self.normalize = normalize
 | 
			
		||||
        self.causal = causal
 | 
			
		||||
 | 
			
		||||
        self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
 | 
			
		||||
        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
 | 
			
		||||
        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
 | 
			
		||||
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
 | 
			
		||||
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 | 
			
		||||
        self.dropout = nn.Identity()
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        self.apply(self._initialize_weights)
 | 
			
		||||
 | 
			
		||||
    def _initialize_weights(self, module: nn.Module):
 | 
			
		||||
        if getattr(module, "_is_hf_initialized", False):
 | 
			
		||||
            return
 | 
			
		||||
        if isinstance(module, nn.Linear):
 | 
			
		||||
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                nn.init.zeros_(module.bias)
 | 
			
		||||
        module._is_hf_initialized = True
 | 
			
		||||
 | 
			
		||||
    def forward(self, hidden_states: torch.Tensor, **kwargs):
 | 
			
		||||
        mode = self.mode
 | 
			
		||||
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
 | 
			
		||||
        q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v])
 | 
			
		||||
        q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
 | 
			
		||||
        if mode == "fused_chunk":
 | 
			
		||||
            o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)
 | 
			
		||||
        elif mode == 'chunk':
 | 
			
		||||
            o = chunk_linear_attn(q, k, v, normalize=True, scale=1)
 | 
			
		||||
        elif mode == 'parallel':
 | 
			
		||||
            assert q.shape[-1] <= 128
 | 
			
		||||
            o = parallel_rebased(q, k, v, self.eps, True, True)
 | 
			
		||||
        o = rearrange(o, "b h l d -> b l (h d)")
 | 
			
		||||
        o = self.o_proj(o)
 | 
			
		||||
        o = self.dropout(o)
 | 
			
		||||
        return o
 | 
			
		||||
 | 
			
		||||
    # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
 | 
			
		||||
    def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        x (torch.Tensor): tensor of shape (b, d, l)
 | 
			
		||||
        y (torch.Tensor): tensor of shape (b, d, l)
 | 
			
		||||
        """
 | 
			
		||||
        # hidden_states = hidden_states.transpose(1, 2)
 | 
			
		||||
        b, l, _ = hidden_states.size()
 | 
			
		||||
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
 | 
			
		||||
 | 
			
		||||
        q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2)
 | 
			
		||||
        k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
 | 
			
		||||
        v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
        # Linear attention
 | 
			
		||||
        q, k = self.feature_map(q), self.feature_map(k)
 | 
			
		||||
        q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
 | 
			
		||||
 | 
			
		||||
        # Compute attention
 | 
			
		||||
        if self.causal:
 | 
			
		||||
            y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
 | 
			
		||||
        else:
 | 
			
		||||
            y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
 | 
			
		||||
        y = rearrange(y, 'b h l d -> b l (h d)')
 | 
			
		||||
        y = self.o_proj(y.to(hidden_states.dtype))
 | 
			
		||||
        y = self.dropout(y)
 | 
			
		||||
        return y.to(hidden_states.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    batch = 4
 | 
			
		||||
    seq_len = 1024
 | 
			
		||||
    hidden_size = 1024
 | 
			
		||||
    dtype = torch.float32
 | 
			
		||||
    x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True)
 | 
			
		||||
    dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda()
 | 
			
		||||
    model = ReBasedLinearAttention(hidden_size=hidden_size, mode='parallel').to(dtype).cuda()
 | 
			
		||||
 | 
			
		||||
    y = model(x)
 | 
			
		||||
    y.backward(dy, retain_graph=True)
 | 
			
		||||
    x_grad, x.grad = x.grad, None
 | 
			
		||||
    print(model.mode)
 | 
			
		||||
    model.mode = 'fused_chunk'
 | 
			
		||||
    y2 = model(x)
 | 
			
		||||
    print(model.mode)
 | 
			
		||||
    y2.backward(dy)
 | 
			
		||||
    # assert y.allclose(y2, 0, 1e-4), breakpoint()
 | 
			
		||||
    # assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint()
 | 
			
		||||
    print("Pass")
 | 
			
		||||
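
A minimal sanity-check sketch (not part of the vendored file): the causal branch of `forward_reference` above computes row-normalized linear attention with cumulative sums. Assuming queries and keys have already been feature-mapped to non-negative values, the cumsum form matches an explicit masked-attention computation:

import torch

b, h, l, d = 2, 2, 8, 4
q = torch.randn(b, h, l, d).abs()  # stand-ins for feature-mapped queries
k = torch.randn(b, h, l, d).abs()  # and keys (non-negative by assumption)
v = torch.randn(b, h, l, d)
eps = 1e-5

# cumulative-sum form, as in forward_reference
qe, ke, ve = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
y_cum = (qe * (ke * ve).cumsum(2)).sum(-1) / ((qe * ke.cumsum(2)).sum(-1) + eps)

# explicit causal attention with the same normalization
scores = torch.einsum('bhid,bhjd->bhij', q, k)
scores = scores.masked_fill(~torch.tril(torch.ones(l, l, dtype=torch.bool)), 0)
y_attn = (scores @ v) / (scores.sum(-1, keepdim=True) + eps)

print(torch.allclose(y_cum, y_attn, atol=1e-4))  # expected: True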
							
								
								
									
264 finetune/lora/v6/fla/layers/rwkv6.py vendored Normal file
@@ -0,0 +1,264 @@
# -*- coding: utf-8 -*-

# "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence"[https://arxiv.org/abs/2404.05892]

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache

from fla.modules import FusedLayerNormSwishGate, LayerNorm
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6


class RWKV6Attention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        gate_fn: str = 'swish',
        proj_low_rank_dim: int = 32,
        gate_low_rank_dim: int = 64,
        fuse_norm: bool = True,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None,
        **kwargs
    ) -> RWKV6Attention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.proj_low_rank_dim = proj_low_rank_dim
        self.gate_low_rank_dim = gate_low_rank_dim

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.x_proj = nn.Sequential(
            LerpLinear(hidden_size, proj_low_rank_dim * 5),
            nn.Tanh(),
            nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=True)
        )
        self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
        self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
        self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
        self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
        self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
        self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_qk_dim))

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        if gate_fn == 'swish' and fuse_norm:
            self.g_norm_swish_gate = FusedLayerNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = LayerNorm(self.head_v_dim, elementwise_affine, norm_eps)
            self.gate_fn = ACT2FN[gate_fn]

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        if isinstance(module, nn.Parameter):
            nn.init.xavier_uniform_(module, gain=2 ** -2.5)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        batch_size, seq_len, hidden_size = hidden_states.size()
        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode

        delta = self.time_shift(hidden_states) - hidden_states
        x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
        r, w, k, v, g = torch.einsum('b l n r, n r d-> b l n d',
                                     self.x_proj[1](x),
                                     self.x_proj[2].weight.view(5, -1, hidden_size)).unbind(-2)
        r = self.r_proj(hidden_states, r, delta)
        w = self.w_proj(hidden_states, w, delta)
        k = self.k_proj(hidden_states, k, delta)
        v = self.v_proj(hidden_states, v, delta)
        g = self.g_proj(hidden_states, g, delta)

        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask.unsqueeze(-1))
        r, w, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (r, w, k, v))
        w = -torch.exp(w)
        u = self.bonus

        last_state = past_key_values[self.layer_idx] if use_cache else None
        state = last_state[-1] if use_cache else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_rwkv6(r, k, v, w, u, initial_state=state, output_final_state=use_cache)
        elif mode == 'chunk':
            o, recurrent_state = chunk_rwkv6(r, k, v, w, u, initial_state=state, output_final_state=use_cache)
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update((recurrent_state,), self.layer_idx, r.shape[2])

        o = rearrange(o, 'b h l d -> b l h d')
        if self.fuse_norm_and_gate:
            g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
            o = self.g_norm_swish_gate(o, g)
            o = rearrange(o, 'b l h d -> b l (h d)')
        else:
            o = self.g_norm(o)
            o = rearrange(o, 'b l h d -> b l (h d)')
            o = o * self.gate_fn(g)
        o = self.o_proj(o)

        return o, None, past_key_values

    def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
        param = next(self.parameters())
        state = (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
        return state

    def state_size(self, **kwargs) -> int:
        state_size = self.key_dim * self.head_v_dim
        return state_size


class LoRA(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: int,
        bias: Optional[bool] = True
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim
        self.bias = bias

        self.lora = nn.Sequential(
            nn.Linear(input_dim, low_rank_dim, bias=False),
            nn.Tanh(),
            nn.Linear(low_rank_dim, output_dim, bias=bias)
        )

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}("
        s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
        if not self.bias:
            s += f", bias={self.bias}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lora(x)


class LerpLinear(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: Optional[int] = None
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        if low_rank_dim is None:
            self.linear = nn.Linear(input_dim, output_dim, bias=False)
        else:
            self.linear = LoRA(input_dim, output_dim, low_rank_dim)
        self.mu = nn.Parameter(torch.zeros(input_dim))

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
        if self.low_rank_dim is not None:
            s += f", low_rank_dim={self.low_rank_dim}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
        if delta is None:
            shifted = self.time_shift(x)
            if len(shifted.shape) == 2:
                shifted = shifted.unsqueeze(1)
            delta = shifted - x
        return self.linear(x + delta * self.mu)


class DDLerpLinear(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: Optional[int] = None
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        if low_rank_dim is None:
            self.linear = nn.Linear(input_dim, output_dim, bias=False)
        else:
            self.linear = LoRA(input_dim, output_dim, low_rank_dim)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
        if self.low_rank_dim is not None:
            s += f", low_rank_dim={self.low_rank_dim}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
        if delta is None:
            shifted = self.time_shift(x)
            if len(shifted.shape) == 2:
                shifted = shifted.unsqueeze(1)
            delta = shifted - x
        return self.linear(x + delta * mu)
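
A hedged usage sketch (not part of the committed file) for driving `RWKV6Attention` on random input; it assumes the vendored `fla` package is importable and that a CUDA device is available, since `chunk_rwkv6` and `fused_recurrent_rwkv6` are Triton kernels:

import torch

from fla.layers.rwkv6 import RWKV6Attention

layer = RWKV6Attention(mode='chunk', hidden_size=512, num_heads=4, layer_idx=0).cuda().bfloat16()
x = torch.randn(2, 64, 512, device='cuda', dtype=torch.bfloat16)

# forward returns (output, attentions, past_key_values); attentions is always None here
o, _, _ = layer(x)
print(o.shape)  # torch.Size([2, 64, 512])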
							
								
								
									
143 finetune/lora/v6/fla/layers/simple_gla.py vendored Normal file
@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN

from fla.modules import FusedRMSNormSwishGate, RMSNorm
from fla.ops.simple_gla import chunk_simple_gla


class SimpleGatedLinearAttention(nn.Module):
    r"""
    The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635).  # noqa
    This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.

    Args:
        mode (str, Optional):
            Which GLA kernel to use.
            Currently available: `chunk`.
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 1.0.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 2.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        gate_fn (str, Optional):
            The activation function for the output gate. Default: `swish`.
        elementwise_affine (bool, Optional):
            If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        gate_logit_normalizer (int, Optional):
            The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
        fuse_norm (bool, Optional):
            Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 2.0,
        num_heads: int = 4,
        gate_fn: str = 'swish',
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 16,
        fuse_norm: bool = True,
        **kwargs
    ) -> SimpleGatedLinearAttention:
        super().__init__()
        self.hidden_size = hidden_size

        self.mode = mode
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        assert mode in ['chunk'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
        self.num_heads = num_heads
        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.gate_fn = ACT2FN[gate_fn]

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        self.gk_proj = nn.Linear(hidden_size, self.num_heads)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        if gate_fn == 'swish' and fuse_norm:
            self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)

        self.gate_logit_normalizer = gate_logit_normalizer

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(self, x):
        mode = self.mode
        q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        gk = rearrange(self.gk_proj(x), 'b n h -> b h n')
        gk = (F.logsigmoid(gk) / self.gate_logit_normalizer)

        if mode == 'chunk':
            o = chunk_simple_gla(q, k, v, gk)
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        o = rearrange(o, 'b h l d -> b l h d')
        g = self.g_proj(x)

        if self.fuse_norm_and_gate:
            g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
            o = self.g_norm_swish_gate(o, g)
            o = rearrange(o, 'b l h d -> b l (h d)')
        else:
            o = self.g_norm(o)
            o = rearrange(o, 'b l h d -> b l (h d)')
            o = o * self.gate_fn(g)
        o = self.o_proj(o)
        return o


if __name__ == '__main__':
    batch = 4
    seq_len = 1024

    hidden_size = 2048
    x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True)
    model = SimpleGatedLinearAttention(hidden_size=hidden_size, mode='chunk').to(torch.bfloat16).cuda()
    y = model(x)
    print(y.shape)
    y.sum().backward()
    print(x.grad.shape)
							
								
								
									
29 finetune/lora/v6/fla/models/__init__.py vendored Normal file
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-

from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel
from fla.models.delta_net import (DeltaNetConfig, DeltaNetForCausalLM,
                                  DeltaNetModel)
from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel
from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel
from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model
from fla.models.linear_attn import (LinearAttentionConfig,
                                    LinearAttentionForCausalLM,
                                    LinearAttentionModel)
from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel
from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel
from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model
from fla.models.transformer import (TransformerConfig, TransformerForCausalLM,
                                    TransformerModel)

__all__ = [
    'ABCConfig', 'ABCForCausalLM', 'ABCModel',
    'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
    'GLAConfig', 'GLAForCausalLM', 'GLAModel',
    'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel',
    'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model',
    'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel',
    'MambaConfig', 'MambaForCausalLM', 'MambaModel',
    'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel',
    'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
    'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'
]
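
Every architecture re-exported here follows the same (Config, ForCausalLM, Model) pattern on top of the transformers API. A minimal sketch of the intended use (assuming the vendored `fla` package is on PYTHONPATH; the argument names are taken from configuration_abc.py later in this commit):

from fla.models import ABCConfig, ABCForCausalLM

config = ABCConfig(vocab_size=1000, hidden_size=256, num_hidden_layers=2, num_heads=4)
model = ABCForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))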
							
								
								
									
13 finetune/lora/v6/fla/models/abc/__init__.py vendored Normal file
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.abc.configuration_abc import ABCConfig
from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel

AutoConfig.register(ABCConfig.model_type, ABCConfig)
AutoModel.register(ABCConfig, ABCModel)
AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM)


__all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel']
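
The register() calls above hook the 'abc' model type into the transformers Auto* factories. A hedged sketch of the effect (it assumes `fla` and a recent `transformers` are installed, and that importing the package has run the registrations):

from transformers import AutoConfig, AutoModelForCausalLM

import fla.models.abc  # noqa: F401  (side effect: runs the register() calls above)

config = AutoConfig.for_model('abc', hidden_size=256, num_hidden_layers=2, num_heads=4)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # ABCForCausalLM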
							
								
								
									
74 finetune/lora/v6/fla/models/abc/configuration_abc.py vendored Normal file
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class ABCConfig(PretrainedConfig):

    model_type = 'abc'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        gate_low_rank_dim: int = 16,
        clamp_min: float = -32,
        clamp_max: float = 32,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = True,
        conv_size: int = 4,
        share_conv_kernel: bool = True,
        exapnd_k: float = 0.5,
        exapnd_v: float = 1,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.gate_low_rank_dim = gate_low_rank_dim
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.share_conv_kernel = share_conv_kernel
        self.expand_k = exapnd_k
        self.expand_v = exapnd_v
        self.hidden_act = hidden_act
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_cross_entropy = fuse_cross_entropy
        self.fuse_norm = fuse_norm

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
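
Note that the constructor arguments are spelled `exapnd_k`/`exapnd_v` in this vendored copy but are stored as `expand_k`/`expand_v`, so serialized configs carry the corrected names. A small sketch of standard PretrainedConfig behaviour (assuming transformers is installed):

from fla.models.abc import ABCConfig

config = ABCConfig(hidden_size=256, num_hidden_layers=2, num_heads=4)
print(config.expand_k, config.expand_v)            # 0.5 1
print('"expand_k"' in config.to_json_string())     # True: stored under the corrected name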
							
								
								
									
394 finetune/lora/v6/fla/models/abc/modeling_abc.py vendored Normal file
@@ -0,0 +1,394 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                           CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from fla.layers.abc import ABCAttention
from fla.models.abc.configuration_abc import ABCConfig
from fla.models.utils import RecurrentCache
from fla.modules import FusedCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear

logger = logging.get_logger(__name__)


class ABCMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish'
    ) -> ABCMLP:
        super().__init__()

        self.hidden_size = hidden_size
        # the final number of params is `hidden_ratio * hidden_size^2`
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        y = self.gate_proj(x)
        gate, y = y.chunk(2, -1)
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)


class ABCBlock(nn.Module):
    def __init__(self, config: ABCConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.attn = ABCAttention(
            hidden_size=config.hidden_size,
            expand_k=config.expand_k,
            expand_v=config.expand_v,
            num_heads=config.num_heads,
            num_slots=config.num_slots,
            use_short_conv=config.use_short_conv,
            conv_size=config.conv_size,
            share_conv_kernel=config.share_conv_kernel,
            gate_fn=config.hidden_act,
            elementwise_affine=config.elementwise_affine,
            norm_eps=config.norm_eps,
            clamp_min=config.clamp_min,
            clamp_max=config.clamp_max,
            fuse_norm=config.fuse_norm,
            layer_idx=layer_idx
        )
        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.mlp = ABCMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states

        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions
        )
        hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class ABCPreTrainedModel(PreTrainedModel):

    config_class = ABCConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ['ABCBlock']

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = True,
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["o_proj.weight", "down_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)


class ABCModel(ABCPreTrainedModel):

    def __init__(self, config: ABCConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([ABCBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`ABCModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache:
            if past_key_values is None:
                past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
            if not isinstance(past_key_values, RecurrentCache):
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = past_key_values.to_legacy_cache()
        if not return_dict:
            return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class ABCForCausalLM(ABCPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = ABCModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is passed along.
        if past_key_values is not None:
            if not isinstance(past_key_values, RecurrentCache):
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            model_inputs = {'input_ids': input_ids}
        model_inputs['past_key_values'] = past_key_values
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            if self.config.fuse_cross_entropy:
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
            else:
                loss_fct = nn.CrossEntropyLoss()
            # Enable model parallelism
            labels = labels.to(logits.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
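
A hedged training-step sketch (not part of the diff; it assumes the vendored `fla` package and its Triton kernels are usable on a CUDA device): `ABCForCausalLM.forward` shifts the labels internally, so the same token ids can be passed as both `input_ids` and `labels`.

import torch

from fla.models import ABCConfig, ABCForCausalLM

config = ABCConfig(vocab_size=1000, hidden_size=256, num_hidden_layers=2, num_heads=4,
                   fuse_cross_entropy=False)  # use plain nn.CrossEntropyLoss instead of the fused kernel
model = ABCForCausalLM(config).cuda()

input_ids = torch.randint(0, config.vocab_size, (2, 32), device='cuda')
out = model(input_ids=input_ids, labels=input_ids, use_cache=False)
print(out.loss.item(), out.logits.shape)  # scalar loss, torch.Size([2, 32, 1000])
out.loss.backward()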
							
								
								
									
14 finetune/lora/v6/fla/models/delta_net/__init__.py vendored Normal file
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.delta_net.configuration_delta_net import \
    DeltaNetConfig
from fla.models.delta_net.modeling_delta_net import (
    DeltaNetForCausalLM, DeltaNetModel)

AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig)
AutoModel.register(DeltaNetConfig, DeltaNetModel)
AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)

__all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel']
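[Editor's note] The register() calls above hook DeltaNet into the transformers Auto* factories as soon as this module is imported. A minimal usage sketch, not part of the diff; the sizes are placeholders and it assumes the `fla` package is installed:

from transformers import AutoConfig, AutoModelForCausalLM
from fla.models.delta_net import DeltaNetConfig  # importing this module runs the register() calls above

# Tiny illustrative sizes, just to show the Auto* lookup resolves to the DeltaNet classes.
config = AutoConfig.for_model('delta_net', hidden_size=256, num_hidden_layers=2, num_heads=2)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # -> DeltaNetForCausalLM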
							
								
								
									
77 finetune/lora/v6/fla/models/delta_net/configuration_delta_net.py vendored Normal file
@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class DeltaNetConfig(PretrainedConfig):

    model_type = 'delta_net'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        expand_k: int = 1,
        expand_v: int = 1,
        use_gate: bool = False,
        use_short_conv: bool = True,
        conv_size: int = 4,
        share_conv_kernel: bool = False,
        use_rope: bool = False,
        use_beta: bool = True,
        use_output_norm: bool = True,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        attn_mode: str = "chunk",
        qk_norm: str = 'l2',
        qk_activation: str = 'silu',
        chunk_size: int = 64,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_cross_entropy: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.attn_mode = attn_mode
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_cross_entropy = fuse_cross_entropy
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.share_conv_kernel = share_conv_kernel
        self.use_rope = use_rope
        self.use_beta = use_beta
        self.use_output_norm = use_output_norm
        self.qk_norm = qk_norm
        self.qk_activation = qk_activation

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
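[Editor's note] A short sketch of instantiating the config defined above; the values are illustrative only, and the defaults follow the signature (vocab_size=32000, hidden_size=2048, 24 layers):

from fla.models.delta_net.configuration_delta_net import DeltaNetConfig

# Hypothetical small configuration for a quick smoke test.
config = DeltaNetConfig(hidden_size=1024, num_hidden_layers=12, num_heads=8)
print(config.model_type, config.attn_mode, config.use_short_conv)  # delta_net chunk True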
							
								
								
									
405 finetune/lora/v6/fla/models/delta_net/modeling_delta_net.py vendored Normal file
@@ -0,0 +1,405 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                           CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from fla.layers.delta_net import DeltaNet
from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
from fla.models.utils import RecurrentCache
from fla.modules import FusedCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear

logger = logging.get_logger(__name__)


class DeltaNetMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish'
    ) -> DeltaNetMLP:
        super().__init__()

        self.hidden_size = hidden_size
        # the final number of params is `hidden_ratio * hidden_size^2`
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        y = self.gate_proj(x)
        gate, y = y.chunk(2, -1)
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)


class DeltaNetBlock(nn.Module):
    def __init__(self, config: DeltaNetConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
        self.attn = DeltaNet(
            mode=config.attn_mode,
            hidden_size=config.hidden_size,
            expand_k=config.expand_k,
            expand_v=config.expand_v,
            num_heads=config.num_heads,
            use_gate=config.use_gate,
            use_rope=config.use_rope,
            use_beta=config.use_beta,
            use_short_conv=config.use_short_conv,
            use_output_norm=config.use_output_norm,
            conv_size=config.conv_size,
            share_conv_kernel=config.share_conv_kernel,
            layer_idx=layer_idx,
            qk_norm=config.qk_norm,
            qk_activation=config.qk_activation
        )
        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = DeltaNetMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states

        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions
        )
        hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class DeltaNetPreTrainedModel(PreTrainedModel):

    config_class = DeltaNetConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ['DeltaNetBlock']

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = True,
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["o_proj.weight", "down_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)


class DeltaNetModel(DeltaNetPreTrainedModel):

    def __init__(self, config: DeltaNetConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([DeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`DeltaNetModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache:
            if past_key_values is None:
                past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
            if not isinstance(past_key_values, RecurrentCache):
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = past_key_values
        # if use_cache:
            # next_cache = past_key_values.to_legacy_cache()
        if not return_dict:
            return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class DeltaNetForCausalLM(DeltaNetPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DeltaNetModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is passed along.
        if past_key_values is not None:
            if not isinstance(past_key_values, RecurrentCache):
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
            # breakpoint()
            input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache'),
            'attention_mask': attention_mask,
        })
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            if self.config.fuse_cross_entropy:
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
            else:
                loss_fct = nn.CrossEntropyLoss()
            # Enable model parallelism
            labels = labels.to(logits.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
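[Editor's note] In `DeltaNetForCausalLM.forward` the labels are shifted left by one and the last position is filled with `ignore_index`, so the logits at position t are scored against token t+1. A standalone sketch of that alignment, using the default `ignore_index` of `nn.CrossEntropyLoss` (-100):

import torch
import torch.nn as nn

loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
labels = torch.tensor([[11, 12, 13, 14]])
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
print(shifted)  # tensor([[  12,   13,   14, -100]]); the padded last position is excluded from the loss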
							
								
								
									
13 finetune/lora/v6/fla/models/gla/__init__.py vendored Normal file
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.gla.configuration_gla import GLAConfig
from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel

AutoConfig.register(GLAConfig.model_type, GLAConfig)
AutoModel.register(GLAConfig, GLAModel)
AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM)


__all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel']
							
								
								
									
80 finetune/lora/v6/fla/models/gla/configuration_gla.py vendored Normal file
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class GLAConfig(PretrainedConfig):

    model_type = 'gla'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        expand_k: int = 0.5,
        expand_v: int = 1,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        attn_mode: str = "chunk",
        use_short_conv: bool = False,
        conv_size: int = 4,
        share_conv_kernel: bool = True,
        use_output_gate: bool = True,
        clamp_min: Optional[float] = None,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_gk: bool = True,
        use_gv: bool = False,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.feature_map = feature_map
        self.attn_mode = attn_mode
        self.clamp_min = clamp_min
        self.hidden_act = hidden_act
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_gk = use_gk
        self.use_gv = use_gv
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_norm = fuse_norm
        self.fuse_cross_entropy = fuse_cross_entropy
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.share_conv_kernel = share_conv_kernel
        self.use_output_gate = use_output_gate

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
							
								
								
									
403 finetune/lora/v6/fla/models/gla/modeling_gla.py vendored Normal file
@@ -0,0 +1,403 @@
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.utils.checkpoint
 | 
			
		||||
from transformers.activations import ACT2FN
 | 
			
		||||
from transformers.modeling_outputs import (BaseModelOutputWithPast,
 | 
			
		||||
                                           CausalLMOutputWithPast)
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
from fla.layers.gla import GatedLinearAttention
 | 
			
		||||
from fla.models.gla.configuration_gla import GLAConfig
 | 
			
		||||
from fla.models.utils import RecurrentCache
 | 
			
		||||
from fla.modules import FusedCrossEntropyLoss, RMSNorm
 | 
			
		||||
from fla.modules.activations import swiglu_linear
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GLAMLP(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        hidden_ratio: Optional[int] = None,
 | 
			
		||||
        intermediate_size: Optional[int] = None,
 | 
			
		||||
        hidden_act: str = 'swish'
 | 
			
		||||
    ) -> GLAMLP:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        # the final number of params is `hidden_ratio * hidden_size^2`
 | 
			
		||||
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
 | 
			
		||||
        if hidden_ratio is None:
 | 
			
		||||
            hidden_ratio = 4
 | 
			
		||||
        if intermediate_size is None:
 | 
			
		||||
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
 | 
			
		||||
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
 | 
			
		||||
        self.hidden_ratio = hidden_ratio
 | 
			
		||||
        self.intermediate_size = intermediate_size
 | 
			
		||||
 | 
			
		||||
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
 | 
			
		||||
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 | 
			
		||||
        self.act_fn = ACT2FN[hidden_act]
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        y = self.gate_proj(x)
 | 
			
		||||
        gate, y = y.chunk(2, -1)
 | 
			
		||||
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GLABlock(nn.Module):
 | 
			
		||||
    def __init__(self, config: GLAConfig, layer_idx: int):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.hidden_size = config.hidden_size
 | 
			
		||||
 | 
			
		||||
        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
        self.attn = GatedLinearAttention(
 | 
			
		||||
            mode=config.attn_mode,
 | 
			
		||||
            hidden_size=config.hidden_size,
 | 
			
		||||
            expand_k=config.expand_k,
 | 
			
		||||
            expand_v=config.expand_v,
 | 
			
		||||
            num_heads=config.num_heads,
 | 
			
		||||
            num_kv_heads=config.num_kv_heads,
 | 
			
		||||
            feature_map=config.feature_map,
 | 
			
		||||
            use_short_conv=config.use_short_conv,
 | 
			
		||||
            conv_size=config.conv_size,
 | 
			
		||||
            share_conv_kernel=config.share_conv_kernel,
 | 
			
		||||
            use_output_gate=config.use_output_gate,
 | 
			
		||||
            gate_fn=config.hidden_act,
 | 
			
		||||
            elementwise_affine=config.elementwise_affine,
 | 
			
		||||
            norm_eps=config.norm_eps,
 | 
			
		||||
            clamp_min=config.clamp_min,
 | 
			
		||||
            fuse_norm=config.fuse_norm,
 | 
			
		||||
            layer_idx=layer_idx
 | 
			
		||||
        )
 | 
			
		||||
        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
        self.mlp = GLAMLP(
 | 
			
		||||
            hidden_size=config.hidden_size,
 | 
			
		||||
            hidden_ratio=config.hidden_ratio,
 | 
			
		||||
            intermediate_size=config.intermediate_size,
 | 
			
		||||
            hidden_act=config.hidden_act
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        use_cache: Optional[bool] = False,
 | 
			
		||||
        output_attentions: Optional[bool] = False,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 | 
			
		||||
        residual = hidden_states
 | 
			
		||||
        hidden_states = self.attn_norm(hidden_states)
 | 
			
		||||
        hidden_states, attentions, past_key_values = self.attn(
 | 
			
		||||
            hidden_states=hidden_states,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            past_key_values=past_key_values,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            output_attentions=output_attentions
 | 
			
		||||
        )
 | 
			
		||||
        hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
 | 
			
		||||
        hidden_states = self.mlp(hidden_states)
 | 
			
		||||
        hidden_states = residual + hidden_states
 | 
			
		||||
 | 
			
		||||
        outputs = (hidden_states, attentions, past_key_values)
 | 
			
		||||
 | 
			
		||||
        return outputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GLAPreTrainedModel(PreTrainedModel):
 | 
			
		||||
 | 
			
		||||
    config_class = GLAConfig
 | 
			
		||||
    supports_gradient_checkpointing = True
 | 
			
		||||
    _no_split_modules = ['GLABlock']
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *inputs, **kwargs):
 | 
			
		||||
        super().__init__(*inputs, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def _init_weights(
 | 
			
		||||
        self,
 | 
			
		||||
        module: nn.Module,
 | 
			
		||||
        rescale_prenorm_residual: bool = True,
 | 
			
		||||
        num_residuals_per_layer: int = 2,
 | 
			
		||||
    ):
 | 
			
		||||
        if isinstance(module, (nn.Linear, nn.Conv1d)):
 | 
			
		||||
            # Slightly different from the TF version which uses truncated_normal for initialization
 | 
			
		||||
            # cf https://github.com/pytorch/pytorch/pull/5617
 | 
			
		||||
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                nn.init.zeros_(module.bias)
 | 
			
		||||
        elif isinstance(module, nn.Embedding):
 | 
			
		||||
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
 | 
			
		||||
            if module.padding_idx is not None:
 | 
			
		||||
                module.weight.data[module.padding_idx].zero_()
 | 
			
		||||
 | 
			
		||||
        if rescale_prenorm_residual:
 | 
			
		||||
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
 | 
			
		||||
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
 | 
			
		||||
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
 | 
			
		||||
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
 | 
			
		||||
            #
 | 
			
		||||
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
 | 
			
		||||
            for name, p in module.named_parameters():
 | 
			
		||||
                if name in ["o_proj.weight", "down_proj.weight"]:
 | 
			
		||||
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
 | 
			
		||||
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
 | 
			
		||||
                    # We need to reinit p since this code could be called multiple times
 | 
			
		||||
                    # Having just p *= scale would repeatedly scale it down
 | 
			
		||||
                    with torch.no_grad():
 | 
			
		||||
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GLAModel(GLAPreTrainedModel):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config: GLAConfig):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.padding_idx = config.pad_token_id
 | 
			
		||||
        self.vocab_size = config.vocab_size
 | 
			
		||||
 | 
			
		||||
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 | 
			
		||||
        self.layers = nn.ModuleList([GLABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
 | 
			
		||||
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
 | 
			
		||||
        self.gradient_checkpointing = False
 | 
			
		||||
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, value):
 | 
			
		||||
        self.embeddings = value
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,  # noqa
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None
 | 
			
		||||
    ) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            warnings.warn("`GLAModel` does not `output_attentions` now, setting it to `False`.")
 | 
			
		||||
            output_attentions = False
 | 
			
		||||
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 | 
			
		||||
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        # retrieve input_ids and inputs_embeds
 | 
			
		||||
        if input_ids is not None and inputs_embeds is not None:
 | 
			
		||||
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 | 
			
		||||
        elif input_ids is not None:
 | 
			
		||||
            batch_size = input_ids.shape[0]
 | 
			
		||||
        elif inputs_embeds is not None:
 | 
			
		||||
            batch_size = inputs_embeds.shape[0]
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("You have to specify either input_ids or inputs_embeds")
 | 
			
		||||
 | 
			
		||||
        if inputs_embeds is None:
 | 
			
		||||
            inputs_embeds = self.embeddings(input_ids)
 | 
			
		||||
        hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            if past_key_values is None:
 | 
			
		||||
                past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
 | 
			
		||||
            if not isinstance(past_key_values, RecurrentCache):
 | 
			
		||||
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values)
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
            if use_cache:
 | 
			
		||||
                logger.warning_once(
 | 
			
		||||
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 | 
			
		||||
                )
 | 
			
		||||
                use_cache = False
 | 
			
		||||
 | 
			
		||||
        all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
        all_attns = () if output_attentions else None
 | 
			
		||||
        for layer in self.layers:
 | 
			
		||||
            if output_hidden_states:
 | 
			
		||||
                all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
            if self.gradient_checkpointing and self.training:
 | 
			
		||||
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
 | 
			
		||||
                    layer.__call__,
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask,
 | 
			
		||||
                    past_key_values,
 | 
			
		||||
                    use_cache,
 | 
			
		||||
                    output_attentions
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                hidden_states, attentions, past_key_values = layer(
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask=attention_mask,
 | 
			
		||||
                    past_key_values=past_key_values,
 | 
			
		||||
                    use_cache=use_cache,
 | 
			
		||||
                    output_attentions=output_attentions
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if output_attentions:
 | 
			
		||||
                all_attns += (attentions,)
 | 
			
		||||
 | 
			
		||||
        hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
        # add hidden states from the last decoder layer
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        next_cache = None
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            next_cache = past_key_values.to_legacy_cache()
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
 | 
			
		||||
        return BaseModelOutputWithPast(
 | 
			
		||||
            last_hidden_state=hidden_states,
 | 
			
		||||
            past_key_values=next_cache,
 | 
			
		||||
            hidden_states=all_hidden_states,
 | 
			
		||||
            attentions=all_attns
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GLAForCausalLM(GLAPreTrainedModel):
 | 
			
		||||
    _tied_weights_keys = ["lm_head.weight"]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.model = GLAModel(config)
 | 
			
		||||
        self.vocab_size = config.vocab_size
 | 
			
		||||
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 | 
			
		||||
 | 
			
		||||
        # Initialize weights and apply final processing
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.model.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, value):
 | 
			
		||||
        self.model.embeddings = value
 | 
			
		||||
 | 
			
		||||
    def get_output_embeddings(self):
 | 
			
		||||
        return self.lm_head
 | 
			
		||||
 | 
			
		||||
    def set_output_embeddings(self, new_embeddings):
 | 
			
		||||
        self.lm_head = new_embeddings
 | 
			
		||||
 | 
			
		||||
    def set_decoder(self, decoder):
 | 
			
		||||
        self.model = decoder
 | 
			
		||||
 | 
			
		||||
    def get_decoder(self):
 | 
			
		||||
        return self.model
 | 
			
		||||
 | 
			
		||||
    def generate(self, *args, **kwargs):
 | 
			
		||||
        try:
 | 
			
		||||
            return super().generate(*args, **kwargs)
 | 
			
		||||
        except AttributeError as exception:
 | 
			
		||||
            if 'past_key_values' in str(exception):
 | 
			
		||||
                raise AttributeError(
 | 
			
		||||
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
 | 
			
		||||
                    f"which is not supported for {self.__class__.__name__}. "
 | 
			
		||||
                    f"Try another generation strategy instead. "
 | 
			
		||||
                    f"For the available generation strategies, check this doc: "
 | 
			
		||||
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                raise exception
 | 
			
		||||
 | 
			
		||||
    def prepare_inputs_for_generation(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.Tensor] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        # only last token for `inputs_ids` if the `past_key_values` is passed along.
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            if not isinstance(past_key_values, RecurrentCache):
 | 
			
		||||
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
 | 
			
		||||
            input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]
 | 
			
		||||
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 | 
			
		||||
        if inputs_embeds is not None and past_key_values is None:
 | 
			
		||||
            model_inputs = {'inputs_embeds': inputs_embeds}
 | 
			
		||||
        else:
 | 
			
		||||
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
 | 
			
		||||
            # recompiles graphs as the stride of the inputs is a guard.
 | 
			
		||||
            # Ref: https://github.com/huggingface/transformers/pull/29114
 | 
			
		||||
            # TODO: use `next_tokens` directly instead.
 | 
			
		||||
            model_inputs = {'input_ids': input_ids.contiguous()}
 | 
			
		||||
 | 
			
		||||
        model_inputs.update({
 | 
			
		||||
            'past_key_values': past_key_values,
 | 
			
		||||
            'use_cache': kwargs.get('use_cache'),
 | 
			
		||||
            'attention_mask': attention_mask,
 | 
			
		||||
        })
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        labels: Optional[torch.LongTensor] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None,
 | 
			
		||||
    ) -> Union[Tuple, CausalLMOutputWithPast]:
 | 
			
		||||
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 | 
			
		||||
        output_hidden_states = (
 | 
			
		||||
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 | 
			
		||||
        )
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        outputs = self.model(
 | 
			
		||||
            input_ids=input_ids,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            inputs_embeds=inputs_embeds,
 | 
			
		||||
            past_key_values=past_key_values,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            output_attentions=output_attentions,
 | 
			
		||||
            output_hidden_states=output_hidden_states,
 | 
			
		||||
            return_dict=return_dict
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        hidden_states = outputs[0]
 | 
			
		||||
        logits = self.lm_head(hidden_states)
 | 
			
		||||
 | 
			
		||||
        loss = None
 | 
			
		||||
        if labels is not None:
 | 
			
		||||
            if self.config.fuse_cross_entropy:
 | 
			
		||||
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
 | 
			
		||||
            else:
 | 
			
		||||
                loss_fct = nn.CrossEntropyLoss()
 | 
			
		||||
            # Enable model parallelism
 | 
			
		||||
            labels = labels.to(logits.device)
 | 
			
		||||
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
 | 
			
		||||
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
 | 
			
		||||
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            output = (logits,) + outputs[1:]
 | 
			
		||||
            return (loss,) + output if loss is not None else output
 | 
			
		||||
 | 
			
		||||
        return CausalLMOutputWithPast(
 | 
			
		||||
            loss=loss,
 | 
			
		||||
            logits=logits,
 | 
			
		||||
            past_key_values=outputs.past_key_values,
 | 
			
		||||
            hidden_states=outputs.hidden_states,
 | 
			
		||||
            attentions=outputs.attentions,
 | 
			
		||||
        )
 | 
			
		||||
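The loss computation in the `forward` above shifts the labels left by one position so that the logits at step t are scored against token t+1, padding the final slot with `ignore_index`. A minimal standalone sketch of that shift, using hypothetical token ids:

import torch
import torch.nn as nn

loss_fct = nn.CrossEntropyLoss()           # ignore_index defaults to -100
labels = torch.tensor([[11, 12, 13, 14]])  # hypothetical token ids
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
# shifted == [[12, 13, 14, -100]]: position t now holds the next-token target, the last position is ignored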
							
								
								
									
13 finetune/lora/v6/fla/models/hgrn/__init__.py (vendored, new file)
@@ -0,0 +1,13 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 | 
			
		||||
 | 
			
		||||
from fla.models.hgrn.configuration_hgrn import HGRNConfig
 | 
			
		||||
from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel
 | 
			
		||||
 | 
			
		||||
AutoConfig.register(HGRNConfig.model_type, HGRNConfig)
 | 
			
		||||
AutoModel.register(HGRNConfig, HGRNModel)
 | 
			
		||||
AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel']
 | 
			
		||||
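Because the registrations above hook HGRN into the transformers Auto classes, the model can be constructed through the generic factories. A hedged sketch, assuming this vendored `fla` package and its dependencies are importable; the sizes are illustrative:

from transformers import AutoModelForCausalLM
from fla.models.hgrn import HGRNConfig

config = HGRNConfig(hidden_size=512, num_hidden_layers=2, vocab_size=1000)
model = AutoModelForCausalLM.from_config(config)  # resolves to HGRNForCausalLM via the registration above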
							
								
								
									
66 finetune/lora/v6/fla/models/hgrn/configuration_hgrn.py (vendored, new file)
@@ -0,0 +1,66 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from transformers.configuration_utils import PretrainedConfig
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRNConfig(PretrainedConfig):
 | 
			
		||||
 | 
			
		||||
    model_type = 'hgrn'
 | 
			
		||||
    keys_to_ignore_at_inference = ['past_key_values']
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        attn_mode: str = "chunk",
 | 
			
		||||
        vocab_size: int = 32000,
 | 
			
		||||
        hidden_size: int = 2048,
 | 
			
		||||
        num_hidden_layers: int = 24,
 | 
			
		||||
        num_heads: Optional[int] = 1,
 | 
			
		||||
        expand_ratio: Optional[int] = 1,
 | 
			
		||||
        use_short_conv: bool = False,
 | 
			
		||||
        conv_size: int = 4,
 | 
			
		||||
        share_conv_kernel: bool = True,
 | 
			
		||||
        use_lower_bound: bool = True,
 | 
			
		||||
        hidden_ratio: Optional[int] = 4,
 | 
			
		||||
        intermediate_size: Optional[int] = None,
 | 
			
		||||
        hidden_act: str = "swish",
 | 
			
		||||
        max_position_embeddings: int = 2048,
 | 
			
		||||
        elementwise_affine: Optional[bool] = True,
 | 
			
		||||
        norm_eps: float = 1e-6,
 | 
			
		||||
        use_cache: bool = True,
 | 
			
		||||
        pad_token_id: int = None,
 | 
			
		||||
        bos_token_id: int = 1,
 | 
			
		||||
        eos_token_id: int = 2,
 | 
			
		||||
        tie_word_embeddings: bool = False,
 | 
			
		||||
        initializer_range: float = 0.02,
 | 
			
		||||
        fuse_cross_entropy: bool = True,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        self.attn_mode = attn_mode
 | 
			
		||||
        self.vocab_size = vocab_size
 | 
			
		||||
        self.max_position_embeddings = max_position_embeddings
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.num_hidden_layers = num_hidden_layers
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
        self.expand_ratio = expand_ratio
 | 
			
		||||
        self.use_short_conv = use_short_conv
 | 
			
		||||
        self.conv_size = conv_size
 | 
			
		||||
        self.share_conv_kernel = share_conv_kernel
 | 
			
		||||
        self.use_lower_bound = use_lower_bound
 | 
			
		||||
        self.hidden_ratio = hidden_ratio
 | 
			
		||||
        self.intermediate_size = intermediate_size
 | 
			
		||||
        self.hidden_act = hidden_act
 | 
			
		||||
        self.elementwise_affine = elementwise_affine
 | 
			
		||||
        self.norm_eps = norm_eps
 | 
			
		||||
        self.use_cache = use_cache
 | 
			
		||||
        self.initializer_range = initializer_range
 | 
			
		||||
        self.fuse_cross_entropy = fuse_cross_entropy
 | 
			
		||||
 | 
			
		||||
        super().__init__(
 | 
			
		||||
            pad_token_id=pad_token_id,
 | 
			
		||||
            bos_token_id=bos_token_id,
 | 
			
		||||
            eos_token_id=eos_token_id,
 | 
			
		||||
            tie_word_embeddings=tie_word_embeddings,
 | 
			
		||||
            **kwargs,
 | 
			
		||||
        )
 | 
			
		||||
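`HGRNConfig` subclasses `PretrainedConfig`, so it serializes through the standard helpers; a small sketch with an illustrative directory and values (anything not passed keeps the defaults listed above):

from fla.models.hgrn import HGRNConfig

cfg = HGRNConfig(hidden_size=1024, num_hidden_layers=12)
cfg.save_pretrained("/tmp/hgrn-config-demo")                    # hypothetical path
reloaded = HGRNConfig.from_pretrained("/tmp/hgrn-config-demo")
assert reloaded.hidden_size == 1024 and reloaded.attn_mode == "chunk"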
							
								
								
									
407 finetune/lora/v6/fla/models/hgrn/modeling_hgrn.py (vendored, new file)
@@ -0,0 +1,407 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.utils.checkpoint
 | 
			
		||||
from transformers.activations import ACT2FN
 | 
			
		||||
from transformers.modeling_outputs import (BaseModelOutputWithPast,
 | 
			
		||||
                                           CausalLMOutputWithPast)
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
from fla.layers.hgrn import HGRNAttention
 | 
			
		||||
from fla.models.hgrn.configuration_hgrn import HGRNConfig
 | 
			
		||||
from fla.models.utils import RecurrentCache
 | 
			
		||||
from fla.modules import FusedCrossEntropyLoss, RMSNorm
 | 
			
		||||
from fla.modules.activations import swiglu_linear
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRNMLP(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        hidden_ratio: Optional[int] = None,
 | 
			
		||||
        intermediate_size: Optional[int] = None,
 | 
			
		||||
        hidden_act: str = 'swish'
 | 
			
		||||
    ) -> HGRNMLP:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        # the final number of params is `hidden_ratio * hidden_size^2`
 | 
			
		||||
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
 | 
			
		||||
        if hidden_ratio is None:
 | 
			
		||||
            hidden_ratio = 4
 | 
			
		||||
        if intermediate_size is None:
 | 
			
		||||
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
 | 
			
		||||
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
 | 
			
		||||
        self.hidden_ratio = hidden_ratio
 | 
			
		||||
        self.intermediate_size = intermediate_size
 | 
			
		||||
 | 
			
		||||
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
 | 
			
		||||
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 | 
			
		||||
        self.act_fn = ACT2FN[hidden_act]
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        y = self.gate_proj(x)
 | 
			
		||||
        gate, y = y.chunk(2, -1)
 | 
			
		||||
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
 | 
			
		||||
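# --- Worked example (illustrative values, not part of the vendored module) ---
# The intermediate_size rounding in HGRNMLP.__init__ above, for hidden_size=2048
# and the default hidden_ratio=4:
#   int(2048 * 4 * 2 / 3)            -> 5461
#   256 * ((5461 + 256 - 1) // 256)  -> 5632, the next multiple of 256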
 | 
			
		||||
 | 
			
		||||
class HGRNBlock(nn.Module):
 | 
			
		||||
    def __init__(self, config: HGRNConfig, layer_idx: int):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.hidden_size = config.hidden_size
 | 
			
		||||
 | 
			
		||||
        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
        self.attn = HGRNAttention(
 | 
			
		||||
            mode=config.attn_mode,
 | 
			
		||||
            hidden_size=config.hidden_size,
 | 
			
		||||
            num_heads=config.num_heads,
 | 
			
		||||
            expand_ratio=config.expand_ratio,
 | 
			
		||||
            use_short_conv=config.use_short_conv,
 | 
			
		||||
            conv_size=config.conv_size,
 | 
			
		||||
            share_conv_kernel=config.share_conv_kernel,
 | 
			
		||||
            elementwise_affine=config.elementwise_affine,
 | 
			
		||||
            norm_eps=config.norm_eps,
 | 
			
		||||
            layer_idx=layer_idx
 | 
			
		||||
        )
 | 
			
		||||
        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
        self.mlp = HGRNMLP(
 | 
			
		||||
            hidden_size=config.hidden_size,
 | 
			
		||||
            hidden_ratio=config.hidden_ratio,
 | 
			
		||||
            intermediate_size=config.intermediate_size,
 | 
			
		||||
            hidden_act=config.hidden_act
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        use_cache: Optional[bool] = False,
 | 
			
		||||
        output_attentions: Optional[bool] = False,
 | 
			
		||||
        lower_bound: Optional[torch.Tensor] = None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 | 
			
		||||
        residual = hidden_states
 | 
			
		||||
        hidden_states = self.attn_norm(hidden_states)
 | 
			
		||||
        hidden_states, attentions, past_key_values = self.attn(
 | 
			
		||||
            hidden_states=hidden_states,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            past_key_values=past_key_values,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            output_attentions=output_attentions,
 | 
			
		||||
            lower_bound=lower_bound
 | 
			
		||||
        )
 | 
			
		||||
        hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
 | 
			
		||||
        hidden_states = self.mlp(hidden_states)
 | 
			
		||||
        hidden_states = residual + hidden_states
 | 
			
		||||
 | 
			
		||||
        outputs = (hidden_states, attentions, past_key_values)
 | 
			
		||||
 | 
			
		||||
        return outputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRNPreTrainedModel(PreTrainedModel):
 | 
			
		||||
 | 
			
		||||
    config_class = HGRNConfig
 | 
			
		||||
    supports_gradient_checkpointing = True
 | 
			
		||||
    _no_split_modules = ['HGRNBlock']
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *inputs, **kwargs):
 | 
			
		||||
        super().__init__(*inputs, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def _init_weights(
 | 
			
		||||
        self,
 | 
			
		||||
        module: nn.Module,
 | 
			
		||||
        rescale_prenorm_residual: bool = True,
 | 
			
		||||
        num_residuals_per_layer: int = 2,
 | 
			
		||||
    ):
 | 
			
		||||
        if isinstance(module, (nn.Linear, nn.Conv1d)):
 | 
			
		||||
            # Slightly different from the TF version which uses truncated_normal for initialization
 | 
			
		||||
            # cf https://github.com/pytorch/pytorch/pull/5617
 | 
			
		||||
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                nn.init.zeros_(module.bias)
 | 
			
		||||
        elif isinstance(module, nn.Embedding):
 | 
			
		||||
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
 | 
			
		||||
            if module.padding_idx is not None:
 | 
			
		||||
                module.weight.data[module.padding_idx].zero_()
 | 
			
		||||
 | 
			
		||||
        if rescale_prenorm_residual:
 | 
			
		||||
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
 | 
			
		||||
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
 | 
			
		||||
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
 | 
			
		||||
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
 | 
			
		||||
            #
 | 
			
		||||
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
 | 
			
		||||
            for name, p in module.named_parameters():
 | 
			
		||||
                if name in ["o_proj.weight", "down_proj.weight"]:
 | 
			
		||||
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
 | 
			
		||||
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
 | 
			
		||||
                    # We need to reinit p since this code could be called multiple times
 | 
			
		||||
                    # Having just p *= scale would repeatedly scale it down
 | 
			
		||||
                    with torch.no_grad():
 | 
			
		||||
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRNModel(HGRNPreTrainedModel):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config: HGRNConfig):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.padding_idx = config.pad_token_id
 | 
			
		||||
        self.vocab_size = config.vocab_size
 | 
			
		||||
 | 
			
		||||
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 | 
			
		||||
        if config.use_lower_bound:
 | 
			
		||||
            self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size))
 | 
			
		||||
        self.layers = nn.ModuleList([HGRNBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
 | 
			
		||||
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
 | 
			
		||||
        self.gradient_checkpointing = False
 | 
			
		||||
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, value):
 | 
			
		||||
        self.embeddings = value
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,  # noqa
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None
 | 
			
		||||
    ) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            warnings.warn("`HGRNModel` does not support `output_attentions` for now; setting it to `False`.")
 | 
			
		||||
            output_attentions = False
 | 
			
		||||
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 | 
			
		||||
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        # retrieve input_ids and inputs_embeds
 | 
			
		||||
        if input_ids is not None and inputs_embeds is not None:
 | 
			
		||||
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 | 
			
		||||
        elif input_ids is not None:
 | 
			
		||||
            batch_size = input_ids.shape[0]
 | 
			
		||||
        elif inputs_embeds is not None:
 | 
			
		||||
            batch_size = inputs_embeds.shape[0]
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("You have to specify either input_ids or inputs_embeds")
 | 
			
		||||
 | 
			
		||||
        if inputs_embeds is None:
 | 
			
		||||
            inputs_embeds = self.embeddings(input_ids)
 | 
			
		||||
        hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            if past_key_values is None:
 | 
			
		||||
                past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
 | 
			
		||||
            if not isinstance(past_key_values, RecurrentCache):
 | 
			
		||||
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values)
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
            if use_cache:
 | 
			
		||||
                logger.warning_once(
 | 
			
		||||
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 | 
			
		||||
                )
 | 
			
		||||
                use_cache = False
 | 
			
		||||
 | 
			
		||||
        all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
        all_attns = () if output_attentions else None
 | 
			
		||||
 | 
			
		||||
        if self.config.use_lower_bound:
 | 
			
		||||
            lower_bounds = self.lower_bounds.softmax(0)
 | 
			
		||||
            lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]
 | 
			
		||||
        for i, layer in enumerate(self.layers):
 | 
			
		||||
            if output_hidden_states:
 | 
			
		||||
                all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
            lower_bound = lower_bounds[i] if self.config.use_lower_bound else None
 | 
			
		||||
            if self.gradient_checkpointing and self.training:
 | 
			
		||||
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
 | 
			
		||||
                    layer.__call__,
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask,
 | 
			
		||||
                    past_key_values,
 | 
			
		||||
                    use_cache,
 | 
			
		||||
                    output_attentions,
 | 
			
		||||
                    lower_bound
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                hidden_states, attentions, past_key_values = layer(
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask=attention_mask,
 | 
			
		||||
                    past_key_values=past_key_values,
 | 
			
		||||
                    use_cache=use_cache,
 | 
			
		||||
                    output_attentions=output_attentions,
 | 
			
		||||
                    lower_bound=lower_bound
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if output_attentions:
 | 
			
		||||
                all_attns += (attentions,)
 | 
			
		||||
 | 
			
		||||
        hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
        # add hidden states from the last decoder layer
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        next_cache = None
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            next_cache = past_key_values.to_legacy_cache()
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
 | 
			
		||||
        return BaseModelOutputWithPast(
 | 
			
		||||
            last_hidden_state=hidden_states,
 | 
			
		||||
            past_key_values=next_cache,
 | 
			
		||||
            hidden_states=all_hidden_states,
 | 
			
		||||
            attentions=all_attns
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRNForCausalLM(HGRNPreTrainedModel):
 | 
			
		||||
    _tied_weights_keys = ["lm_head.weight"]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.model = HGRNModel(config)
 | 
			
		||||
        self.vocab_size = config.vocab_size
 | 
			
		||||
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 | 
			
		||||
 | 
			
		||||
        # Initialize weights and apply final processing
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.model.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, value):
 | 
			
		||||
        self.model.embeddings = value
 | 
			
		||||
 | 
			
		||||
    def get_output_embeddings(self):
 | 
			
		||||
        return self.lm_head
 | 
			
		||||
 | 
			
		||||
    def set_output_embeddings(self, new_embeddings):
 | 
			
		||||
        self.lm_head = new_embeddings
 | 
			
		||||
 | 
			
		||||
    def set_decoder(self, decoder):
 | 
			
		||||
        self.model = decoder
 | 
			
		||||
 | 
			
		||||
    def get_decoder(self):
 | 
			
		||||
        return self.model
 | 
			
		||||
 | 
			
		||||
    def generate(self, *args, **kwargs):
 | 
			
		||||
        try:
 | 
			
		||||
            return super().generate(*args, **kwargs)
 | 
			
		||||
        except AttributeError as exception:
 | 
			
		||||
            if 'past_key_values' in str(exception):
 | 
			
		||||
                raise AttributeError(
 | 
			
		||||
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
 | 
			
		||||
                    f"which is not supported for {self.__class__.__name__}. "
 | 
			
		||||
                    f"Try another generation strategy instead. "
 | 
			
		||||
                    f"For the available generation strategies, check this doc: "
 | 
			
		||||
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                raise exception
 | 
			
		||||
 | 
			
		||||
    def prepare_inputs_for_generation(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.Tensor] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        # only keep the last token of `input_ids` if `past_key_values` is passed along.
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            if not isinstance(past_key_values, RecurrentCache):
 | 
			
		||||
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
 | 
			
		||||
            input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]
 | 
			
		||||
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 | 
			
		||||
        if inputs_embeds is not None and past_key_values is None:
 | 
			
		||||
            model_inputs = {'inputs_embeds': inputs_embeds}
 | 
			
		||||
        else:
 | 
			
		||||
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
 | 
			
		||||
            # recompiles graphs as the stride of the inputs is a guard.
 | 
			
		||||
            # Ref: https://github.com/huggingface/transformers/pull/29114
 | 
			
		||||
            # TODO: use `next_tokens` directly instead.
 | 
			
		||||
            model_inputs = {'input_ids': input_ids.contiguous()}
 | 
			
		||||
 | 
			
		||||
        model_inputs.update({
 | 
			
		||||
            'past_key_values': past_key_values,
 | 
			
		||||
            'use_cache': kwargs.get('use_cache'),
 | 
			
		||||
            'attention_mask': attention_mask,
 | 
			
		||||
        })
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        labels: Optional[torch.LongTensor] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None,
 | 
			
		||||
    ) -> Union[Tuple, CausalLMOutputWithPast]:
 | 
			
		||||
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 | 
			
		||||
        output_hidden_states = (
 | 
			
		||||
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 | 
			
		||||
        )
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        outputs = self.model(
 | 
			
		||||
            input_ids=input_ids,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            inputs_embeds=inputs_embeds,
 | 
			
		||||
            past_key_values=past_key_values,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            output_attentions=output_attentions,
 | 
			
		||||
            output_hidden_states=output_hidden_states,
 | 
			
		||||
            return_dict=return_dict
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        hidden_states = outputs[0]
 | 
			
		||||
        logits = self.lm_head(hidden_states)
 | 
			
		||||
 | 
			
		||||
        loss = None
 | 
			
		||||
        if labels is not None:
 | 
			
		||||
            if self.config.fuse_cross_entropy:
 | 
			
		||||
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
 | 
			
		||||
            else:
 | 
			
		||||
                loss_fct = nn.CrossEntropyLoss()
 | 
			
		||||
            # Enable model parallelism
 | 
			
		||||
            labels = labels.to(logits.device)
 | 
			
		||||
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
 | 
			
		||||
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
 | 
			
		||||
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            output = (logits,) + outputs[1:]
 | 
			
		||||
            return (loss,) + output if loss is not None else output
 | 
			
		||||
 | 
			
		||||
        return CausalLMOutputWithPast(
 | 
			
		||||
            loss=loss,
 | 
			
		||||
            logits=logits,
 | 
			
		||||
            past_key_values=outputs.past_key_values,
 | 
			
		||||
            hidden_states=outputs.hidden_states,
 | 
			
		||||
            attentions=outputs.attentions,
 | 
			
		||||
        )
 | 
			
		||||
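The `use_lower_bound` branch of `HGRNModel.forward` above turns the learned `lower_bounds` parameter into a monotone per-layer schedule. A minimal sketch of that computation, assuming a toy 4-layer model with hidden size 3 (values illustrative):

import torch

raw = torch.zeros(4, 3)    # stands in for self.lower_bounds: (num_hidden_layers, hidden_size), zero-initialized
lb = raw.softmax(0)        # normalize across layers for every hidden dimension
lb = lb.cumsum(0) - lb[0]  # layer 0 gets 0; deeper layers get increasing bounds
# with the zero init above, lb[:, 0] == tensor([0.00, 0.25, 0.50, 0.75]); lb[i] is passed to layer i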
							
								
								
									
13 finetune/lora/v6/fla/models/hgrn2/__init__.py (vendored, new file)
@@ -0,0 +1,13 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 | 
			
		||||
 | 
			
		||||
from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
 | 
			
		||||
from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model
 | 
			
		||||
 | 
			
		||||
AutoConfig.register(HGRN2Config.model_type, HGRN2Config)
 | 
			
		||||
AutoModel.register(HGRN2Config, HGRN2Model)
 | 
			
		||||
AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model']
 | 
			
		||||
							
								
								
									
66 finetune/lora/v6/fla/models/hgrn2/configuration_hgrn2.py (vendored, new file)
@@ -0,0 +1,66 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from transformers.configuration_utils import PretrainedConfig
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRN2Config(PretrainedConfig):
 | 
			
		||||
 | 
			
		||||
    model_type = 'hgrn2'
 | 
			
		||||
    keys_to_ignore_at_inference = ['past_key_values']
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        vocab_size: int = 32000,
 | 
			
		||||
        hidden_size: int = 2048,
 | 
			
		||||
        num_hidden_layers: int = 24,
 | 
			
		||||
        attn_mode: str = "chunk",
 | 
			
		||||
        num_heads: Optional[int] = None,
 | 
			
		||||
        expand_ratio: Optional[int] = 128,
 | 
			
		||||
        use_short_conv: bool = False,
 | 
			
		||||
        conv_size: int = 4,
 | 
			
		||||
        share_conv_kernel: bool = True,
 | 
			
		||||
        use_lower_bound: bool = True,
 | 
			
		||||
        hidden_ratio: Optional[int] = 4,
 | 
			
		||||
        intermediate_size: Optional[int] = None,
 | 
			
		||||
        hidden_act: str = "swish",
 | 
			
		||||
        max_position_embeddings: int = 2048,
 | 
			
		||||
        elementwise_affine: Optional[bool] = True,
 | 
			
		||||
        norm_eps: float = 1e-6,
 | 
			
		||||
        use_cache: bool = True,
 | 
			
		||||
        pad_token_id: int = None,
 | 
			
		||||
        bos_token_id: int = 1,
 | 
			
		||||
        eos_token_id: int = 2,
 | 
			
		||||
        tie_word_embeddings: bool = False,
 | 
			
		||||
        initializer_range: float = 0.02,
 | 
			
		||||
        fuse_cross_entropy: bool = True,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        self.vocab_size = vocab_size
 | 
			
		||||
        self.max_position_embeddings = max_position_embeddings
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.num_hidden_layers = num_hidden_layers
 | 
			
		||||
        self.attn_mode = attn_mode
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
        self.expand_ratio = expand_ratio
 | 
			
		||||
        self.use_short_conv = use_short_conv
 | 
			
		||||
        self.conv_size = conv_size
 | 
			
		||||
        self.share_conv_kernel = share_conv_kernel
 | 
			
		||||
        self.use_lower_bound = use_lower_bound
 | 
			
		||||
        self.hidden_ratio = hidden_ratio
 | 
			
		||||
        self.intermediate_size = intermediate_size
 | 
			
		||||
        self.hidden_act = hidden_act
 | 
			
		||||
        self.elementwise_affine = elementwise_affine
 | 
			
		||||
        self.norm_eps = norm_eps
 | 
			
		||||
        self.use_cache = use_cache
 | 
			
		||||
        self.initializer_range = initializer_range
 | 
			
		||||
        self.fuse_cross_entropy = fuse_cross_entropy
 | 
			
		||||
 | 
			
		||||
        super().__init__(
 | 
			
		||||
            pad_token_id=pad_token_id,
 | 
			
		||||
            bos_token_id=bos_token_id,
 | 
			
		||||
            eos_token_id=eos_token_id,
 | 
			
		||||
            tie_word_embeddings=tie_word_embeddings,
 | 
			
		||||
            **kwargs,
 | 
			
		||||
        )
 | 
			
		||||
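For quick reference, the HGRN and HGRN2 configs added in this commit differ mainly in how they size the recurrence. A hedged sketch, assuming the vendored `fla` package is importable:

from fla.models.hgrn import HGRNConfig
from fla.models.hgrn2 import HGRN2Config

hgrn, hgrn2 = HGRNConfig(), HGRN2Config()
print(hgrn.num_heads, hgrn.expand_ratio)    # 1 1
print(hgrn2.num_heads, hgrn2.expand_ratio)  # None 128 (head count presumably resolved inside HGRN2Attention)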
							
								
								
									
407 finetune/lora/v6/fla/models/hgrn2/modeling_hgrn2.py (vendored, new file)
@@ -0,0 +1,407 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.utils.checkpoint
 | 
			
		||||
from transformers.activations import ACT2FN
 | 
			
		||||
from transformers.modeling_outputs import (BaseModelOutputWithPast,
 | 
			
		||||
                                           CausalLMOutputWithPast)
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
from fla.layers.hgrn2 import HGRN2Attention
 | 
			
		||||
from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
 | 
			
		||||
from fla.models.utils import RecurrentCache
 | 
			
		||||
from fla.modules import FusedCrossEntropyLoss, RMSNorm
 | 
			
		||||
from fla.modules.activations import swiglu_linear
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRN2MLP(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        hidden_ratio: Optional[int] = None,
 | 
			
		||||
        intermediate_size: Optional[int] = None,
 | 
			
		||||
        hidden_act: str = 'swish'
 | 
			
		||||
    ) -> HGRN2MLP:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        # the final number of params is `hidden_ratio * hidden_size^2`
 | 
			
		||||
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
 | 
			
		||||
        if hidden_ratio is None:
 | 
			
		||||
            hidden_ratio = 4
 | 
			
		||||
        if intermediate_size is None:
 | 
			
		||||
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
 | 
			
		||||
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
 | 
			
		||||
        self.hidden_ratio = hidden_ratio
 | 
			
		||||
        self.intermediate_size = intermediate_size
 | 
			
		||||
 | 
			
		||||
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
 | 
			
		||||
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 | 
			
		||||
        self.act_fn = ACT2FN[hidden_act]
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        y = self.gate_proj(x)
 | 
			
		||||
        gate, y = y.chunk(2, -1)
 | 
			
		||||
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRN2Block(nn.Module):
 | 
			
		||||
    def __init__(self, config: HGRN2Config, layer_idx: int):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.hidden_size = config.hidden_size
 | 
			
		||||
 | 
			
		||||
        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
        self.attn = HGRN2Attention(
 | 
			
		||||
            mode=config.attn_mode,
 | 
			
		||||
            hidden_size=config.hidden_size,
 | 
			
		||||
            num_heads=config.num_heads,
 | 
			
		||||
            expand_ratio=config.expand_ratio,
 | 
			
		||||
            use_short_conv=config.use_short_conv,
 | 
			
		||||
            conv_size=config.conv_size,
 | 
			
		||||
            share_conv_kernel=config.share_conv_kernel,
 | 
			
		||||
            elementwise_affine=config.elementwise_affine,
 | 
			
		||||
            norm_eps=config.norm_eps,
 | 
			
		||||
            layer_idx=layer_idx
 | 
			
		||||
        )
 | 
			
		||||
        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
        self.mlp = HGRN2MLP(
 | 
			
		||||
            hidden_size=config.hidden_size,
 | 
			
		||||
            hidden_ratio=config.hidden_ratio,
 | 
			
		||||
            intermediate_size=config.intermediate_size,
 | 
			
		||||
            hidden_act=config.hidden_act
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        use_cache: Optional[bool] = False,
 | 
			
		||||
        output_attentions: Optional[bool] = False,
 | 
			
		||||
        lower_bound: Optional[torch.Tensor] = None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 | 
			
		||||
        residual = hidden_states
 | 
			
		||||
        hidden_states = self.attn_norm(hidden_states)
 | 
			
		||||
        hidden_states, attentions, past_key_values = self.attn(
 | 
			
		||||
            hidden_states=hidden_states,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            past_key_values=past_key_values,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            output_attentions=output_attentions,
 | 
			
		||||
            lower_bound=lower_bound
 | 
			
		||||
        )
 | 
			
		||||
        hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
 | 
			
		||||
        hidden_states = self.mlp(hidden_states)
 | 
			
		||||
        hidden_states = residual + hidden_states
 | 
			
		||||
 | 
			
		||||
        outputs = (hidden_states, attentions, past_key_values)
 | 
			
		||||
 | 
			
		||||
        return outputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRN2PreTrainedModel(PreTrainedModel):
 | 
			
		||||
 | 
			
		||||
    config_class = HGRN2Config
 | 
			
		||||
    supports_gradient_checkpointing = True
 | 
			
		||||
    _no_split_modules = ['HGRN2Block']
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *inputs, **kwargs):
 | 
			
		||||
        super().__init__(*inputs, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def _init_weights(
 | 
			
		||||
        self,
 | 
			
		||||
        module: nn.Module,
 | 
			
		||||
        rescale_prenorm_residual: bool = True,
 | 
			
		||||
        num_residuals_per_layer: int = 2,
 | 
			
		||||
    ):
 | 
			
		||||
        if isinstance(module, (nn.Linear, nn.Conv1d)):
 | 
			
		||||
            # Slightly different from the TF version which uses truncated_normal for initialization
 | 
			
		||||
            # cf https://github.com/pytorch/pytorch/pull/5617
 | 
			
		||||
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                nn.init.zeros_(module.bias)
 | 
			
		||||
        elif isinstance(module, nn.Embedding):
 | 
			
		||||
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
 | 
			
		||||
            if module.padding_idx is not None:
 | 
			
		||||
                module.weight.data[module.padding_idx].zero_()
 | 
			
		||||
 | 
			
		||||
        if rescale_prenorm_residual:
 | 
			
		||||
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
 | 
			
		||||
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
 | 
			
		||||
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
 | 
			
		||||
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
 | 
			
		||||
            #
 | 
			
		||||
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
 | 
			
		||||
            for name, p in module.named_parameters():
 | 
			
		||||
                if name in ["o_proj.weight", "down_proj.weight"]:
 | 
			
		||||
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
 | 
			
		||||
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
 | 
			
		||||
                    # We need to reinit p since this code could be called multiple times
 | 
			
		||||
                    # Having just p *= scale would repeatedly scale it down
 | 
			
		||||
                    with torch.no_grad():
 | 
			
		||||
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRN2Model(HGRN2PreTrainedModel):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config: HGRN2Config):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.padding_idx = config.pad_token_id
 | 
			
		||||
        self.vocab_size = config.vocab_size
 | 
			
		||||
 | 
			
		||||
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 | 
			
		||||
        if config.use_lower_bound:
 | 
			
		||||
            self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size))
 | 
			
		||||
        self.layers = nn.ModuleList([HGRN2Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
 | 
			
		||||
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
 | 
			
		||||
        self.gradient_checkpointing = False
 | 
			
		||||
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, value):
 | 
			
		||||
        self.embeddings = value
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,  # noqa
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None
 | 
			
		||||
    ) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            warnings.warn("`HGRN2Model` does not support `output_attentions` for now; setting it to `False`.")
 | 
			
		||||
            output_attentions = False
 | 
			
		||||
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 | 
			
		||||
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        # retrieve input_ids and inputs_embeds
 | 
			
		||||
        if input_ids is not None and inputs_embeds is not None:
 | 
			
		||||
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 | 
			
		||||
        elif input_ids is not None:
 | 
			
		||||
            batch_size = input_ids.shape[0]
 | 
			
		||||
        elif inputs_embeds is not None:
 | 
			
		||||
            batch_size = inputs_embeds.shape[0]
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("You have to specify either input_ids or inputs_embeds")
 | 
			
		||||
 | 
			
		||||
        if inputs_embeds is None:
 | 
			
		||||
            inputs_embeds = self.embeddings(input_ids)
 | 
			
		||||
        hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            if past_key_values is None:
 | 
			
		||||
                past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
 | 
			
		||||
            if not isinstance(past_key_values, RecurrentCache):
 | 
			
		||||
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values)
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
            if use_cache:
 | 
			
		||||
                logger.warning_once(
 | 
			
		||||
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 | 
			
		||||
                )
 | 
			
		||||
                use_cache = False
 | 
			
		||||
 | 
			
		||||
        all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
        all_attns = () if output_attentions else None
 | 
			
		||||
 | 
			
		||||
        if self.config.use_lower_bound:
 | 
			
		||||
            lower_bounds = self.lower_bounds.softmax(0)
 | 
			
		||||
            lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]
 | 
			
		||||
        for i, layer in enumerate(self.layers):
 | 
			
		||||
            if output_hidden_states:
 | 
			
		||||
                all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
            lower_bound = lower_bounds[i] if self.config.use_lower_bound else None
 | 
			
		||||
            if self.gradient_checkpointing and self.training:
 | 
			
		||||
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
 | 
			
		||||
                    layer.__call__,
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask,
 | 
			
		||||
                    past_key_values,
 | 
			
		||||
                    use_cache,
 | 
			
		||||
                    output_attentions,
 | 
			
		||||
                    lower_bound
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                hidden_states, attentions, past_key_values = layer(
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask=attention_mask,
 | 
			
		||||
                    past_key_values=past_key_values,
 | 
			
		||||
                    use_cache=use_cache,
 | 
			
		||||
                    output_attentions=output_attentions,
 | 
			
		||||
                    lower_bound=lower_bound
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if output_attentions:
 | 
			
		||||
                all_attns += (attentions,)
 | 
			
		||||
 | 
			
		||||
        hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
        # add hidden states from the last decoder layer
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        next_cache = None
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            next_cache = past_key_values.to_legacy_cache()
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
 | 
			
		||||
        return BaseModelOutputWithPast(
 | 
			
		||||
            last_hidden_state=hidden_states,
 | 
			
		||||
            past_key_values=next_cache,
 | 
			
		||||
            hidden_states=all_hidden_states,
 | 
			
		||||
            attentions=all_attns
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HGRN2ForCausalLM(HGRN2PreTrainedModel):
 | 
			
		||||
    _tied_weights_keys = ["lm_head.weight"]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.model = HGRN2Model(config)
 | 
			
		||||
        self.vocab_size = config.vocab_size
 | 
			
		||||
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 | 
			
		||||
 | 
			
		||||
        # Initialize weights and apply final processing
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.model.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, value):
 | 
			
		||||
        self.model.embeddings = value
 | 
			
		||||
 | 
			
		||||
    def get_output_embeddings(self):
 | 
			
		||||
        return self.lm_head
 | 
			
		||||
 | 
			
		||||
    def set_output_embeddings(self, new_embeddings):
 | 
			
		||||
        self.lm_head = new_embeddings
 | 
			
		||||
 | 
			
		||||
    def set_decoder(self, decoder):
 | 
			
		||||
        self.model = decoder
 | 
			
		||||
 | 
			
		||||
    def get_decoder(self):
 | 
			
		||||
        return self.model
 | 
			
		||||
 | 
			
		||||
    def generate(self, *args, **kwargs):
 | 
			
		||||
        try:
 | 
			
		||||
            return super().generate(*args, **kwargs)
 | 
			
		||||
        except AttributeError as exception:
 | 
			
		||||
            if 'past_key_values' in str(exception):
 | 
			
		||||
                raise AttributeError(
 | 
			
		||||
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
 | 
			
		||||
                    f"which is not supported for {self.__class__.__name__}. "
 | 
			
		||||
                    f"Try another generation strategy instead. "
 | 
			
		||||
                    f"For the available generation strategies, check this doc: "
 | 
			
		||||
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                raise exception
 | 
			
		||||
 | 
			
		||||
    def prepare_inputs_for_generation(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.Tensor] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        # only keep the last token of `input_ids` if `past_key_values` is passed along.
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            if not isinstance(past_key_values, RecurrentCache):
 | 
			
		||||
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
 | 
			
		||||
            input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]
 | 
			
		||||
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 | 
			
		||||
        if inputs_embeds is not None and past_key_values is None:
 | 
			
		||||
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache'),
            'attention_mask': attention_mask,
        })
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            if self.config.fuse_cross_entropy:
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
            else:
                loss_fct = nn.CrossEntropyLoss()
            # Enable model parallelism
            labels = labels.to(logits.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
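The label handling in the loss branch above aligns each position's logits with the following token before computing cross-entropy. A minimal standalone sketch of the same shift, with illustrative values only:

```python
import torch
import torch.nn as nn

loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
labels = torch.tensor([[11, 12, 13, 14]])
# Drop the first position and pad the end with ignore_index, so the logits at
# step t are scored against the token observed at step t + 1.
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
# shifted -> tensor([[  12,   13,   14, -100]])
```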
14 finetune/lora/v6/fla/models/linear_attn/__init__.py vendored Normal file
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.linear_attn.configuration_linear_attn import \
    LinearAttentionConfig
from fla.models.linear_attn.modeling_linear_attn import (
    LinearAttentionForCausalLM, LinearAttentionModel)

AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig)
AutoModel.register(LinearAttentionConfig, LinearAttentionModel)
AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM)

__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel']
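With the three `register` calls above, the linear-attention classes resolve through the `Auto*` factories. A small usage sketch, assuming the `fla` package and its kernel dependencies are importable; the tiny sizes are for illustration only:

```python
from transformers import AutoConfig, AutoModelForCausalLM

import fla.models.linear_attn  # noqa: F401  (importing the module runs the registrations above)

config = AutoConfig.for_model('linear_attn', hidden_size=256, num_hidden_layers=2, num_heads=2)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # LinearAttentionForCausalLM
```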
70 finetune/lora/v6/fla/models/linear_attn/configuration_linear_attn.py vendored Normal file
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class LinearAttentionConfig(PretrainedConfig):

    model_type = 'linear_attn'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        expand_k: int = 1,
        expand_v: int = 1,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        attn_mode: str = "fused_chunk",
        feature_map: str = "elementwise_product",
        tie_feature_map_qk: bool = False,
        norm_q: bool = False,
        norm_k: bool = False,
        norm_feature_map: bool = False,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_cross_entropy: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.attn_mode = attn_mode
        self.feature_map = feature_map
        self.tie_feature_map_qk = tie_feature_map_qk
        self.norm_q = norm_q
        self.norm_k = norm_k
        self.norm_feature_map = norm_feature_map
        self.hidden_act = hidden_act
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_cross_entropy = fuse_cross_entropy

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
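A quick instantiation sketch for the class above, in the spirit of the docstring example in the Mamba configuration later in this commit; the overridden values are illustrative only:

```python
from fla.models.linear_attn import LinearAttentionConfig

config = LinearAttentionConfig(hidden_size=1024, num_hidden_layers=12, attn_mode="fused_chunk")
assert config.model_type == 'linear_attn'
assert 'past_key_values' in config.keys_to_ignore_at_inference
```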
424 finetune/lora/v6/fla/models/linear_attn/modeling_linear_attn.py vendored Normal file
@@ -0,0 +1,424 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                           CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from fla.layers.linear_attn import LinearAttention
from fla.models.linear_attn.configuration_linear_attn import \
    LinearAttentionConfig
from fla.modules import FusedCrossEntropyLoss, RMSNorm
from fla.modules.activations import swiglu_linear

logger = logging.get_logger(__name__)


class LinearAttentionMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish'
    ) -> LinearAttentionMLP:
        super().__init__()

        self.hidden_size = hidden_size
        # the final number of params is `hidden_ratio * hidden_size^2`
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        y = self.gate_proj(x)
        gate, y = y.chunk(2, -1)
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)

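# Worked example of the `intermediate_size` rule in LinearAttentionMLP above
# (numbers are illustrative only): with hidden_size=2048 and hidden_ratio=4,
#   int(2048 * 4 * 2 / 3)            == 5461
#   256 * ((5461 + 256 - 1) // 256)  == 5632   # rounded up to the next multiple of 256
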
class LinearAttentionBlock(nn.Module):
    def __init__(self, config: LinearAttentionConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.attn = LinearAttention(
            hidden_size=config.hidden_size,
            expand_k=config.expand_k,
            expand_v=config.expand_v,
            num_heads=config.num_heads,
            mode=config.attn_mode,
            feature_map=config.feature_map,
            tie_feature_map_qk=config.tie_feature_map_qk,
            norm_q=config.norm_q,
            norm_k=config.norm_k,
            do_feature_map_norm=config.norm_feature_map,
            elementwise_affine=config.elementwise_affine,
            norm_eps=config.norm_eps,
            layer_idx=layer_idx
        )
        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.mlp = LinearAttentionMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states
        # currently not supported
        attn_weights, present_key_value = None, None

        hidden_states = self.attn_norm(hidden_states)
        hidden_states = self.attn(hidden_states)
        hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class LinearAttentionPreTrainedModel(PreTrainedModel):
    config_class = LinearAttentionConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ['LinearAttentionBlock']

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = True,
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["o_proj.weight", "down_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)

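# Sketch of the prenorm-residual rescaling in `_init_weights` above (illustrative):
# with num_residuals_per_layer=2 and num_hidden_layers=24, `o_proj.weight` and
# `down_proj.weight` are divided by sqrt(2 * 24) ≈ 6.93, i.e. the 1/sqrt(N) scheme
# from the GPT-2 reference quoted in the comments.
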
class LinearAttentionModel(LinearAttentionPreTrainedModel):

    def __init__(self, config: LinearAttentionConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LinearAttentionBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn(
                "`LinearAttentionModel` does not support outputting attention weights yet, "
                "so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            _, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            _, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        # embed positions
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class LinearAttentionForCausalLM(LinearAttentionPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LinearAttentionModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exc:
            # Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
            if 'past_key_values' in str(exc):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exc

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        state: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs
    ):
        # only use the last token of input_ids if the state is passed along.
        if state is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and state is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs["state"] = state
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            if self.config.fuse_cross_entropy:
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
            else:
                loss_fct = nn.CrossEntropyLoss()
            # Enable model parallelism
            labels = labels.to(logits.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
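`LinearAttentionModel.forward` above wraps a legacy tuple cache in a `DynamicCache` and converts it back with `to_legacy_cache()`. A minimal round-trip sketch of that conversion, assuming the `transformers` version this vendored file targets; the shapes are placeholders:

```python
import torch
from transformers.cache_utils import DynamicCache

# one layer of (key, value), shaped (batch, num_heads, seq_len, head_dim)
legacy = ((torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)),)
cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_usable_length(16))  # tokens already present in the cache
restored = cache.to_legacy_cache()  # back to the tuple format used when no Cache instance was passed
```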
14 finetune/lora/v6/fla/models/mamba/__init__.py vendored Normal file
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.mamba.configuration_mamba import MambaConfig
from fla.models.mamba.modeling_mamba import (MambaBlock, MambaForCausalLM,
                                             MambaModel)

AutoConfig.register(MambaConfig.model_type, MambaConfig, True)
AutoModel.register(MambaConfig, MambaModel, True)
AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True)


__all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock']
156 finetune/lora/v6/fla/models/mamba/configuration_mamba.py vendored Normal file
@@ -0,0 +1,156 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MAMBA configuration"""

import math

from transformers.configuration_utils import PretrainedConfig


class MambaConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the MAMBA
    [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the MAMBA model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`MambaModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimensionality of the embeddings and hidden states.
        state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the model.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon to use in the layer normalization layers.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the beginning of sentence token in the vocabulary.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the end of sentence token in the vocabulary.
        expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
        conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
        use_conv_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias in the convolution layer of the mixer block.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        initializer_range (`float`, *optional*, defaults to 0.1):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        residual_in_fp32 (`bool`, *optional*, defaults to `False`):
            Whether or not residuals should be in `float32`.
            If set to `False` residuals will keep the same `dtype` as the rest of the model
        time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
            Rank of the discretization projection matrix.
            `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
        time_step_scale (`float`, *optional*, defaults to 1.0):
            Scale used to scale `dt_proj.bias`.
        time_step_min (`float`, *optional*, defaults to 0.001):
            Minimum `time_step` used to bound `dt_proj.bias`.
        time_step_max (`float`, *optional*, defaults to 0.1):
            Maximum `time_step` used to bound `dt_proj.bias`.
        time_step_init_scheme (`float`, *optional*, defaults to `"random"`):
            Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]`
        time_step_floor (`float`, *optional*, defaults to 0.0001):
            Minimum clamping value of the `dt_proj.bias` layer initialization.
        rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
            Whether or not to rescale `out_proj` weights when initializing.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the cache should be used.


    Example:

    ```python
    >>> from transformers import MambaConfig, MambaModel

    >>> # Initializing a Mamba configuration
    >>> configuration = MambaConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = MambaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mamba"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=2048,
        state_size=16,
        num_hidden_layers=48,
        layer_norm_epsilon=1e-5,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        expand=2,
        conv_kernel=4,
        use_bias=False,
        use_conv_bias=True,
        hidden_act="silu",
        initializer_range=0.1,
        residual_in_fp32=False,
        time_step_rank="auto",
        time_step_scale=1.0,
        time_step_min=0.001,
        time_step_max=0.1,
        time_step_init_scheme="random",
        time_step_floor=1e-4,
        rescale_prenorm_residual=False,
        use_cache=True,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        tie_word_embeddings: bool = False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.num_hidden_layers = num_hidden_layers
        self.layer_norm_epsilon = layer_norm_epsilon
        self.conv_kernel = conv_kernel
        self.expand = expand
        self.intermediate_size = int(expand * self.hidden_size)
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.use_bias = use_bias
        self.use_conv_bias = use_conv_bias
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
        self.time_step_scale = time_step_scale
        self.time_step_min = time_step_min
        self.time_step_max = time_step_max
        self.time_step_init_scheme = time_step_init_scheme
        self.time_step_floor = time_step_floor
        self.rescale_prenorm_residual = rescale_prenorm_residual
        self.residual_in_fp32 = residual_in_fp32
        self.use_cache = use_cache
        self.fuse_cross_entropy = fuse_cross_entropy
        self.fuse_norm = fuse_norm

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
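Two derived fields set in `__init__` above are easy to sanity-check; a short sketch with the defaults, purely for illustration:

```python
import math

from fla.models.mamba import MambaConfig

config = MambaConfig()  # hidden_size=2048, expand=2, time_step_rank="auto"
assert config.intermediate_size == 2 * 2048            # int(expand * hidden_size)
assert config.time_step_rank == math.ceil(2048 / 16)   # "auto" -> 128
```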
605 finetune/lora/v6/fla/models/mamba/modeling_mamba.py vendored Normal file
@@ -0,0 +1,605 @@
# coding=utf-8
# Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MAMBA model."""

import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from fla.models.mamba.configuration_mamba import MambaConfig
from fla.modules import FusedCrossEntropyLoss, RMSNorm

logger = logging.get_logger(__name__)

try:
    from mamba_ssm.ops.selective_scan_interface import (mamba_inner_fn,
                                                        selective_scan_fn)
    from mamba_ssm.ops.triton.selective_state_update import \
        selective_state_update
except ImportError:
    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all(
    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)


class MambaCache:
    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        self.seqlen_offset = 0
        self.dtype = dtype
        intermediate_size = config.intermediate_size
        ssm_state_size = config.state_size
        conv_kernel_size = config.conv_kernel

        self.conv_states = {
            i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }
        self.ssm_states = {
            i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }

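# Sketch of the per-layer buffers allocated by MambaCache above (sizes are
# illustrative): with intermediate_size=4096, conv_kernel=4 and state_size=16,
# each layer i holds
#   conv_states[i]: (batch_size, 4096, 4)   - rolling window for the causal conv
#   ssm_states[i]:  (batch_size, 4096, 16)  - recurrent SSM state
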
class MambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.intermediate_size
        self.time_step_rank = config.time_step_rank
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.intermediate_size,
            padding=config.conv_kernel - 1,
        )

        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        # projection of the input hidden states
        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
        # time step projection (discretization)
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
        A = A.expand(self.intermediate_size, -1).contiguous()

        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.intermediate_size))
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of "
                "`(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                " is None. Falling back to the naive implementation. "
                "To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

    def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Optional[MambaCache] = None):
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states).transpose(1, 2)

        if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
            contextualized_states = mamba_inner_fn(
                projected_states,
                self.conv1d.weight,
                self.conv1d.bias if self.use_conv_bias else None,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias.float() if self.use_bias else None,
                -torch.exp(self.A_log.float()),
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )

        else:
            hidden_states, gate = projected_states.chunk(2, dim=1)

            # 2. Convolution sequence transformation
            conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
            if cache_params is not None and cache_params.seqlen_offset > 0:
                hidden_states = causal_conv1d_update(
                    hidden_states.squeeze(-1),
                    cache_params.conv_states[self.layer_idx],
                    conv_weights,
                    self.conv1d.bias,
                    self.activation,
                )
                hidden_states = hidden_states.unsqueeze(-1)
            else:
                if cache_params is not None:
                    conv_states = nn.functional.pad(
                        hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
                    )
                    cache_params.conv_states[self.layer_idx].copy_(conv_states)
                hidden_states = causal_conv1d_fn(
                    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
                )

            # 3. State Space Model sequence transformation
            # 3.a. input varying initialization of time_step, B and C
            ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
            time_step, B, C = torch.split(
                ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
            )
            discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)

            A = -torch.exp(self.A_log.float())
            # 3.c perform the recurrence y ← SSM(A, B, C)(x)
            time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
            if cache_params is not None and cache_params.seqlen_offset > 0:
                scan_outputs = selective_state_update(
                    cache_params.ssm_states[self.layer_idx],
                    hidden_states[..., 0],
                    discrete_time_step[..., 0],
                    A,
                    B[:, 0],
                    C[:, 0],
                    self.D,
                    gate[..., 0],
                    time_proj_bias,
                    dt_softplus=True,
                ).unsqueeze(-1)
            else:
                scan_outputs, ssm_state = selective_scan_fn(
                    hidden_states,
                    discrete_time_step,
                    A,
                    B.transpose(1, 2),
                    C.transpose(1, 2),
                    self.D.float(),
                    gate,
                    time_proj_bias,
                    delta_softplus=True,
                    return_last_state=True,
                )
                if ssm_state is not None and cache_params is not None:
                    cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

            # 4. Final linear projection
            contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
        return contextualized_states

    # fmt: off
    def slow_forward(self, input_states, cache_params: Optional[MambaCache] = None):
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype
        # 1. Gated MLP's linear projection
        # [batch, 2 * intermediate_size, seq_len]
        projected_states = self.in_proj(input_states).transpose(1, 2)
        hidden_states, gate = projected_states.chunk(2, dim=1)

        # 2. Convolution sequence transformation
        if cache_params is not None:
            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
            if cache_params.seqlen_offset > 0:
                # [batch, intermediate_size, conv_kernel_size]
                conv_state = cache_params.conv_states[self.layer_idx]
                conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
                conv_state[:, :, -1] = hidden_states[:, :, 0]
                cache_params.conv_states[self.layer_idx].copy_(conv_state)
                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
                if self.use_conv_bias:
                    hidden_states += self.conv1d.bias
                # [batch, intermediate_size, 1] : decoding
                hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)
            else:
                conv_state = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
                )
                cache_params.conv_states[self.layer_idx].copy_(conv_state)
                # [batch, intermediate_size, seq_len]
                hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
        else:
            ssm_state = torch.zeros(
                (batch_size, self.intermediate_size, self.ssm_state_size),
                device=hidden_states.device, dtype=dtype
            )
            # [batch, intermediate_size, seq_len]
            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])

        # 3. State Space Model sequence transformation
        # 3.a. Selection:  [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
        time_step, B, C = torch.split(
            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
        )
        # [batch, seq_len, intermediate_size]
        discrete_time_step = self.dt_proj(time_step)
        # [batch, intermediate_size, seq_len]
        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)

        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
        # [intermediate_size, ssm_state_size]
        A = -torch.exp(self.A_log.float())
        # [batch, intermediate_size, seq_len, ssm_state_size]
        discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
        # [batch, intermediate_size, seq_len, ssm_state_size]
        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        scan_outputs = []
        for i in range(seq_len):
            # [batch, intermediate_size, ssm_state]
            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
            # [batch, intermediate_size, 1]
            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))
            scan_outputs.append(scan_output[:, :, 0])
        # [batch, seq_len, intermediate_size]
        scan_output = torch.stack(scan_outputs, dim=-1)
        scan_output = scan_output + (hidden_states * self.D[None, :, None])
        scan_output = (scan_output * self.act(gate))

        if cache_params is not None:
            cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

        # 4. Final linear projection
        # [batch, seq_len, hidden_size]
        contextualized_states = self.out_proj(scan_output.transpose(1, 2))
        return contextualized_states
    # fmt: on

    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
        if is_fast_path_available and "cuda" in self.x_proj.weight.device.type:
            return self.cuda_kernels_forward(hidden_states, cache_params)
        return self.slow_forward(hidden_states, cache_params)

class MambaBlock(nn.Module):
 | 
			
		||||
    def __init__(self, config, layer_idx):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.config = config
 | 
			
		||||
        self.layer_idx = layer_idx
 | 
			
		||||
        self.residual_in_fp32 = config.residual_in_fp32
 | 
			
		||||
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 | 
			
		||||
        self.mixer = MambaMixer(config, layer_idx=layer_idx)
 | 
			
		||||
 | 
			
		||||
    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
 | 
			
		||||
        residual = hidden_states
 | 
			
		||||
        hidden_states = self.norm(hidden_states)
 | 
			
		||||
        # if self.residual_in_fp32:
 | 
			
		||||
        #     residual = residual.to(torch.float32)
 | 
			
		||||
        hidden_states = self.mixer(hidden_states, cache_params=cache_params)
 | 
			
		||||
        hidden_states = residual + hidden_states
 | 
			
		||||
        return hidden_states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MambaPreTrainedModel(PreTrainedModel):
 | 
			
		||||
    """
 | 
			
		||||
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
 | 
			
		||||
    models.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    config_class = MambaConfig
 | 
			
		||||
    base_model_prefix = "backbone"
 | 
			
		||||
    _no_split_modules = ["MambaBlock"]
 | 
			
		||||
    supports_gradient_checkpointing = True
 | 
			
		||||
 | 
			
		||||
    def _init_weights(self, module):
 | 
			
		||||
        """Initialize the weights."""
 | 
			
		||||
        if isinstance(module, MambaMixer):
 | 
			
		||||
            module.A_log._no_weight_decay = True
 | 
			
		||||
            module.D._no_weight_decay = True
 | 
			
		||||
 | 
			
		||||
            dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
 | 
			
		||||
            if self.config.time_step_init_scheme == "constant":
 | 
			
		||||
                nn.init.constant_(module.dt_proj.weight, dt_init_std)
 | 
			
		||||
            elif self.config.time_step_init_scheme == "random":
 | 
			
		||||
                nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
 | 
			
		||||
 | 
			
		||||
            dt = torch.exp(
 | 
			
		||||
                torch.rand(self.config.intermediate_size)
 | 
			
		||||
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
 | 
			
		||||
                + math.log(self.config.time_step_min)
 | 
			
		||||
            ).clamp(min=self.config.time_step_floor)
 | 
			
		||||
            # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
 | 
			
		||||
            inv_dt = dt + torch.log(-torch.expm1(-dt))
 | 
			
		||||
            with torch.no_grad():
 | 
			
		||||
                module.dt_proj.bias.copy_(inv_dt)
 | 
			
		||||
            module.dt_proj.bias._no_reinit = True
 | 
			
		||||
 | 
			
		||||
        if isinstance(module, nn.Linear):
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                if not getattr(module.bias, "_no_reinit", False):
 | 
			
		||||
                    nn.init.zeros_(module.bias)
 | 
			
		||||
        elif isinstance(module, nn.Embedding):
 | 
			
		||||
            nn.init.normal_(module.weight, std=self.config.initializer_range)
 | 
			
		||||
 | 
			
		||||
        if self.config.rescale_prenorm_residual:
 | 
			
		||||
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
 | 
			
		||||
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
 | 
			
		||||
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
 | 
			
		||||
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
 | 
			
		||||
            #
 | 
			
		||||
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
 | 
			
		||||
            for name, p in module.named_parameters():
 | 
			
		||||
                if name in ["out_proj.weight"]:
 | 
			
		||||
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
 | 
			
		||||
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
 | 
			
		||||
                    # We need to reinit p since this code could be called multiple times
 | 
			
		||||
                    # Having just p *= scale would repeatedly scale it down
 | 
			
		||||
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
 | 
			
		||||
                    with torch.no_grad():
 | 
			
		||||
                        p /= math.sqrt(self.config.num_layers)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class MambaOutput(ModelOutput):
 | 
			
		||||
    """
 | 
			
		||||
    Class for the MAMBA model outputs.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
 | 
			
		||||
            Sequence of hidden-states at the output of the last layer of the model.
 | 
			
		||||
        cache_params (`MambaCache`):
 | 
			
		||||
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
 | 
			
		||||
            avoid providing the old `input_ids`.
 | 
			
		||||
 | 
			
		||||
            Includes both the State space model state matrices after the selective scan, and the Convolutional states
 | 
			
		||||
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
 | 
			
		||||
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
 | 
			
		||||
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
 | 
			
		||||
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
 | 
			
		||||
 | 
			
		||||
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    last_hidden_state: Optional[torch.FloatTensor] = None
 | 
			
		||||
    cache_params: Optional[MambaCache] = None
 | 
			
		||||
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class MambaCausalLMOutput(ModelOutput):
 | 
			
		||||
    """
 | 
			
		||||
    Base class for causal language model (or autoregressive) outputs.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
 | 
			
		||||
            Language modeling loss (for next-token prediction).
 | 
			
		||||
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
 | 
			
		||||
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
 | 
			
		||||
        cache_params (`MambaCache`):
 | 
			
		||||
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
 | 
			
		||||
            avoid providing the old `input_ids`.
 | 
			
		||||
 | 
			
		||||
            Includes both the State space model state matrices after the selective scan, and the Convolutional states
 | 
			
		||||
        hidden_states (`tuple(torch.FloatTensor)`, *optional*,
 | 
			
		||||
            returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
 | 
			
		||||
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
 | 
			
		||||
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
 | 
			
		||||
 | 
			
		||||
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    loss: Optional[torch.FloatTensor] = None
 | 
			
		||||
    logits: Optional[torch.FloatTensor] = None
 | 
			
		||||
    cache_params: Optional[MambaCache] = None
 | 
			
		||||
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MambaModel(MambaPreTrainedModel):
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
 | 
			
		||||
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
 | 
			
		||||
        self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
 | 
			
		||||
 | 
			
		||||
        self.gradient_checkpointing = False
 | 
			
		||||
        self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 | 
			
		||||
        # Initialize weights and apply final processing
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, new_embeddings):
 | 
			
		||||
        self.embeddings = new_embeddings
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.LongTensor] = None,
 | 
			
		||||
        cache_params: Optional[MambaCache] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None,
 | 
			
		||||
        **kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
 | 
			
		||||
    ) -> Union[Tuple, MambaOutput]:
 | 
			
		||||
        output_hidden_states = (
 | 
			
		||||
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 | 
			
		||||
        )
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        if (input_ids is None) ^ (inputs_embeds is not None):  # ^ is python for xor
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if inputs_embeds is None:
 | 
			
		||||
            inputs_embeds = self.embeddings(input_ids)
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training and use_cache:
 | 
			
		||||
            use_cache = False
 | 
			
		||||
 | 
			
		||||
        if cache_params is None and use_cache:
 | 
			
		||||
            cache_params = MambaCache(
 | 
			
		||||
                self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        hidden_states = inputs_embeds
 | 
			
		||||
        all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
        for mixer_block in self.layers:
 | 
			
		||||
            if self.gradient_checkpointing and self.training:
 | 
			
		||||
                hidden_states = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, cache_params)
 | 
			
		||||
            else:
 | 
			
		||||
                hidden_states = mixer_block(hidden_states, cache_params=cache_params)
 | 
			
		||||
 | 
			
		||||
            if output_hidden_states:
 | 
			
		||||
                all_hidden_states = all_hidden_states + (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            cache_params.seqlen_offset += inputs_embeds.shape[1]
 | 
			
		||||
 | 
			
		||||
        hidden_states = self.norm_f(hidden_states)
 | 
			
		||||
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states = all_hidden_states + (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
 | 
			
		||||
 | 
			
		||||
        return MambaOutput(
 | 
			
		||||
            last_hidden_state=hidden_states,
 | 
			
		||||
            cache_params=cache_params if use_cache else None,
 | 
			
		||||
            hidden_states=all_hidden_states,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MambaForCausalLM(MambaPreTrainedModel):
 | 
			
		||||
    _tied_weights_keys = ["lm_head.weight"]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.backbone = MambaModel(config)
 | 
			
		||||
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 | 
			
		||||
        # Initialize weights and apply final processing
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_output_embeddings(self):
 | 
			
		||||
        return self.lm_head
 | 
			
		||||
 | 
			
		||||
    def set_output_embeddings(self, new_embeddings):
 | 
			
		||||
        self.lm_head = new_embeddings
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.backbone.get_input_embeddings()
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, new_embeddings):
 | 
			
		||||
        return self.backbone.set_input_embeddings(new_embeddings)
 | 
			
		||||
 | 
			
		||||
    def _update_model_kwargs_for_generation(
 | 
			
		||||
        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
 | 
			
		||||
    ) -> Dict[str, Any]:
 | 
			
		||||
        model_kwargs["cache_params"] = outputs.get("cache_params", None)
 | 
			
		||||
        return model_kwargs
 | 
			
		||||
 | 
			
		||||
    def prepare_inputs_for_generation(
 | 
			
		||||
        self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        # only last token for inputs_ids if the state is passed along.
 | 
			
		||||
        if cache_params is not None:
 | 
			
		||||
            input_ids = input_ids[:, -1].unsqueeze(-1)
 | 
			
		||||
 | 
			
		||||
        if inputs_embeds is not None and cache_params is None:
 | 
			
		||||
            model_inputs = {"inputs_embeds": inputs_embeds}
 | 
			
		||||
        else:
 | 
			
		||||
            model_inputs = {"input_ids": input_ids}
 | 
			
		||||
 | 
			
		||||
        model_inputs["cache_params"] = cache_params
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        cache_params: Optional[MambaCache] = None,
 | 
			
		||||
        labels: Optional[torch.LongTensor] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        **kwargs,  # for now we need this for generation
 | 
			
		||||
    ) -> Union[Tuple, MambaCausalLMOutput]:
 | 
			
		||||
        r"""
 | 
			
		||||
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 | 
			
		||||
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
 | 
			
		||||
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
 | 
			
		||||
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
 | 
			
		||||
        """
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        mamba_outputs = self.backbone(
 | 
			
		||||
            input_ids,
 | 
			
		||||
            cache_params=cache_params,
 | 
			
		||||
            inputs_embeds=inputs_embeds,
 | 
			
		||||
            output_hidden_states=output_hidden_states,
 | 
			
		||||
            return_dict=return_dict,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
        )
 | 
			
		||||
        hidden_states = mamba_outputs[0]
 | 
			
		||||
        logits = self.lm_head(hidden_states)
 | 
			
		||||
 | 
			
		||||
        loss = None
 | 
			
		||||
        if labels is not None:
 | 
			
		||||
            if self.config.fuse_cross_entropy:
 | 
			
		||||
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
 | 
			
		||||
            else:
 | 
			
		||||
                loss_fct = nn.CrossEntropyLoss()
 | 
			
		||||
            # Enable model parallelism
 | 
			
		||||
            labels = labels.to(logits.device)
 | 
			
		||||
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
 | 
			
		||||
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
 | 
			
		||||
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            output = (logits,) + mamba_outputs[1:]
 | 
			
		||||
            return (loss,) + output if loss is not None else output
 | 
			
		||||
 | 
			
		||||
        return MambaCausalLMOutput(
 | 
			
		||||
            loss=loss,
 | 
			
		||||
            logits=logits,
 | 
			
		||||
            cache_params=mamba_outputs.cache_params,
 | 
			
		||||
            hidden_states=mamba_outputs.hidden_states,
 | 
			
		||||
        )
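
The Python loop in `MambaMixer.slow_forward` above is the reference, non-fused form of the selective scan. A minimal toy sketch of that recurrence, with illustrative shapes and random tensors (not part of the vendored file), follows:

# Hedged sketch of steps 3.b/3.c of slow_forward: discretize A and B per time step,
# update the state, project it with C, and add the D skip connection.
import torch

batch, d_inner, d_state, seq_len = 2, 4, 8, 5
A = -torch.rand(d_inner, d_state)                  # plays the role of -exp(A_log)
dt = torch.nn.functional.softplus(torch.randn(batch, d_inner, seq_len))
B = torch.randn(batch, seq_len, d_state)           # input-dependent ("selective") B
C = torch.randn(batch, seq_len, d_state)           # input-dependent C
u = torch.randn(batch, d_inner, seq_len)           # post-convolution activations
D = torch.randn(d_inner)

state = torch.zeros(batch, d_inner, d_state)
ys = []
for i in range(seq_len):
    dA = torch.exp(A[None] * dt[:, :, i, None])              # [batch, d_inner, d_state]
    dBu = dt[:, :, i, None] * B[:, None, i, :] * u[:, :, i, None]
    state = dA * state + dBu                                  # recurrent state update
    ys.append((state * C[:, None, i, :]).sum(-1))             # project state with C
y = torch.stack(ys, dim=-1) + u * D[None, :, None]            # add the D*u skip
print(y.shape)  # torch.Size([2, 4, 5])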
13  finetune/lora/v6/fla/models/retnet/__init__.py  vendored  Normal file
@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.retnet.configuration_retnet import RetNetConfig
from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel

AutoConfig.register(RetNetConfig.model_type, RetNetConfig)
AutoModel.register(RetNetConfig, RetNetModel)
AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM)


__all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel']
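
A hedged usage sketch (assuming the vendored `fla` package above is on the import path; the override values are illustrative): once this `__init__.py` has run its `register` calls, the RetNet classes are reachable through the standard `transformers` Auto API.

# Illustrative only; importing the package triggers the Auto* registrations above.
from transformers import AutoConfig, AutoModelForCausalLM

import fla.models.retnet  # noqa: F401

config = AutoConfig.for_model("retnet", hidden_size=512, num_hidden_layers=4)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # RetNetForCausalLM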
							
								
								
									
76  finetune/lora/v6/fla/models/retnet/configuration_retnet.py  vendored  Normal file
@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class RetNetConfig(PretrainedConfig):

    model_type = 'retnet'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        expand_k: int = 1,
        expand_v: int = 2,
        hidden_ratio: Optional[int] = 2,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        attn_mode: str = "fused_chunk",
        hidden_act: str = "swish",
        use_short_conv: bool = False,
        conv_size: int = 4,
        share_conv_kernel: bool = True,
        use_output_gate: bool = True,
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        **kwargs
    ) -> RetNetConfig:
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.feature_map = feature_map
        self.attn_mode = attn_mode
        self.hidden_act = hidden_act
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.share_conv_kernel = share_conv_kernel
        self.use_output_gate = use_output_gate
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_norm = fuse_norm
        self.fuse_cross_entropy = fuse_cross_entropy

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
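
For reference, a minimal sketch of constructing this config directly; the overridden values below are illustrative, everything else keeps the defaults listed in the signature above.

# Hedged example: a smaller-than-default RetNet configuration.
from fla.models.retnet.configuration_retnet import RetNetConfig

config = RetNetConfig(
    hidden_size=1024,        # default is 2048
    num_hidden_layers=12,    # default is 24
    num_heads=8,
    attn_mode="fused_chunk",
)
print(config.model_type)     # 'retnet'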
							
								
								
									
410  finetune/lora/v6/fla/models/retnet/modeling_retnet.py  vendored  Normal file
@ -0,0 +1,410 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.utils.checkpoint
 | 
			
		||||
from transformers.activations import ACT2FN
 | 
			
		||||
from transformers.modeling_outputs import (BaseModelOutputWithPast,
 | 
			
		||||
                                           CausalLMOutputWithPast)
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
from fla.layers.multiscale_retention import MultiScaleRetention
 | 
			
		||||
from fla.models.retnet.configuration_retnet import RetNetConfig
 | 
			
		||||
from fla.models.utils import RecurrentCache
 | 
			
		||||
from fla.modules import FusedCrossEntropyLoss, RMSNorm
 | 
			
		||||
from fla.modules.activations import swiglu_linear
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RetNetMLP(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        hidden_ratio: Optional[int] = None,
 | 
			
		||||
        intermediate_size: Optional[int] = None,
 | 
			
		||||
        hidden_act: str = 'swish'
 | 
			
		||||
    ) -> RetNetMLP:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        # the final number of params is `hidden_ratio * hidden_size^2`
 | 
			
		||||
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
 | 
			
		||||
        if hidden_ratio is None:
 | 
			
		||||
            hidden_ratio = 4
 | 
			
		||||
        if intermediate_size is None:
 | 
			
		||||
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
 | 
			
		||||
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
 | 
			
		||||
        self.hidden_ratio = hidden_ratio
 | 
			
		||||
        self.intermediate_size = intermediate_size
 | 
			
		||||
 | 
			
		||||
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
 | 
			
		||||
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 | 
			
		||||
        self.act_fn = ACT2FN[hidden_act]
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        y = self.gate_proj(x)
 | 
			
		||||
        gate, y = y.chunk(2, -1)
 | 
			
		||||
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RetNetBlock(nn.Module):
 | 
			
		||||
    def __init__(self, config: RetNetConfig, layer_idx: int):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.hidden_size = config.hidden_size
 | 
			
		||||
 | 
			
		||||
        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
        self.attn = MultiScaleRetention(
 | 
			
		||||
            mode=config.attn_mode,
 | 
			
		||||
            hidden_size=config.hidden_size,
 | 
			
		||||
            expand_k=config.expand_k,
 | 
			
		||||
            expand_v=config.expand_v,
 | 
			
		||||
            num_heads=config.num_heads,
 | 
			
		||||
            num_kv_heads=config.num_kv_heads,
 | 
			
		||||
            feature_map=config.feature_map,
 | 
			
		||||
            use_output_gate=config.use_output_gate,
 | 
			
		||||
            gate_fn=config.hidden_act,
 | 
			
		||||
            elementwise_affine=config.elementwise_affine,
 | 
			
		||||
            norm_eps=config.norm_eps,
 | 
			
		||||
            fuse_norm=config.fuse_norm,
 | 
			
		||||
            layer_idx=layer_idx
 | 
			
		||||
        )
 | 
			
		||||
        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
        self.mlp = RetNetMLP(
 | 
			
		||||
            hidden_size=config.hidden_size,
 | 
			
		||||
            hidden_ratio=config.hidden_ratio,
 | 
			
		||||
            intermediate_size=config.intermediate_size,
 | 
			
		||||
            hidden_act=config.hidden_act
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
        use_cache: Optional[bool] = False,
 | 
			
		||||
        output_attentions: Optional[bool] = False,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 | 
			
		||||
 | 
			
		||||
        residual = hidden_states
 | 
			
		||||
 | 
			
		||||
        hidden_states = self.attn_norm(hidden_states)
 | 
			
		||||
        hidden_states, attentions, past_key_values = self.attn(
 | 
			
		||||
            hidden_states=hidden_states,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            past_key_values=past_key_values,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            output_attentions=output_attentions
 | 
			
		||||
        )
 | 
			
		||||
        hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
 | 
			
		||||
        hidden_states = self.mlp(hidden_states)
 | 
			
		||||
        hidden_states = residual + hidden_states
 | 
			
		||||
 | 
			
		||||
        outputs = (hidden_states, attentions, past_key_values)
 | 
			
		||||
 | 
			
		||||
        return outputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RetNetPreTrainedModel(PreTrainedModel):
 | 
			
		||||
 | 
			
		||||
    config_class = RetNetConfig
 | 
			
		||||
    supports_gradient_checkpointing = True
 | 
			
		||||
    _no_split_modules = ['RetNetBlock']
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *inputs, **kwargs):
 | 
			
		||||
        super().__init__(*inputs, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def _init_weights(
 | 
			
		||||
        self,
 | 
			
		||||
        module: nn.Module,
 | 
			
		||||
        rescale_prenorm_residual: bool = True,
 | 
			
		||||
        num_residuals_per_layer: int = 2,
 | 
			
		||||
    ):
 | 
			
		||||
        if isinstance(module, (nn.Linear, nn.Conv1d)):
 | 
			
		||||
            # Slightly different from the TF version which uses truncated_normal for initialization
 | 
			
		||||
            # cf https://github.com/pytorch/pytorch/pull/5617
 | 
			
		||||
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                nn.init.zeros_(module.bias)
 | 
			
		||||
        elif isinstance(module, nn.Embedding):
 | 
			
		||||
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
 | 
			
		||||
            if module.padding_idx is not None:
 | 
			
		||||
                module.weight.data[module.padding_idx].zero_()
 | 
			
		||||
 | 
			
		||||
        if rescale_prenorm_residual:
 | 
			
		||||
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
 | 
			
		||||
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
 | 
			
		||||
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
 | 
			
		||||
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
 | 
			
		||||
            #
 | 
			
		||||
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
 | 
			
		||||
            for name, p in module.named_parameters():
 | 
			
		||||
                if name in ["o_proj.weight", "down_proj.weight"]:
 | 
			
		||||
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
 | 
			
		||||
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
 | 
			
		||||
                    # We need to reinit p since this code could be called multiple times
 | 
			
		||||
                    # Having just p *= scale would repeatedly scale it down
 | 
			
		||||
                    with torch.no_grad():
 | 
			
		||||
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RetNetModel(RetNetPreTrainedModel):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config: RetNetConfig):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.padding_idx = config.pad_token_id
 | 
			
		||||
        self.vocab_size = config.vocab_size
 | 
			
		||||
 | 
			
		||||
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 | 
			
		||||
        self.layers = nn.ModuleList(
 | 
			
		||||
            [RetNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 | 
			
		||||
        )
 | 
			
		||||
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
 | 
			
		||||
        self.gradient_checkpointing = False
 | 
			
		||||
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, value):
 | 
			
		||||
        self.embeddings = value
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,  # noqa
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None
 | 
			
		||||
    ) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                "`RetNetModel` does not support output attention weights now, so `output_attentions` is set to `False`."
 | 
			
		||||
            )
 | 
			
		||||
            output_attentions = False
 | 
			
		||||
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 | 
			
		||||
        output_hidden_states = (
 | 
			
		||||
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 | 
			
		||||
        )
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        # retrieve input_ids and inputs_embeds
 | 
			
		||||
        if input_ids is not None and inputs_embeds is not None:
 | 
			
		||||
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 | 
			
		||||
        elif input_ids is not None:
 | 
			
		||||
            batch_size, seq_len = input_ids.shape[:2]
 | 
			
		||||
        elif inputs_embeds is not None:
 | 
			
		||||
            batch_size, seq_len = inputs_embeds.shape[:2]
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("You have to specify either input_ids or inputs_embeds")
 | 
			
		||||
 | 
			
		||||
        if inputs_embeds is None:
 | 
			
		||||
            inputs_embeds = self.embeddings(input_ids)
 | 
			
		||||
        hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            if past_key_values is None:
 | 
			
		||||
                past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
 | 
			
		||||
            if not isinstance(past_key_values, RecurrentCache):
 | 
			
		||||
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values)
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
            if use_cache:
 | 
			
		||||
                logger.warning_once(
 | 
			
		||||
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 | 
			
		||||
                )
 | 
			
		||||
                use_cache = False
 | 
			
		||||
 | 
			
		||||
        all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
        all_attns = () if output_attentions else None
 | 
			
		||||
        for layer in self.layers:
 | 
			
		||||
            if output_hidden_states:
 | 
			
		||||
                all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
            if self.gradient_checkpointing and self.training:
 | 
			
		||||
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
 | 
			
		||||
                    layer.__call__,
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask,
 | 
			
		||||
                    past_key_values,
 | 
			
		||||
                    use_cache,
 | 
			
		||||
                    output_attentions
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                hidden_states, attentions, past_key_values = layer(
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask=attention_mask,
 | 
			
		||||
                    past_key_values=past_key_values,
 | 
			
		||||
                    use_cache=use_cache,
 | 
			
		||||
                    output_attentions=output_attentions
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if output_attentions:
 | 
			
		||||
                all_attns += (attentions,)
 | 
			
		||||
 | 
			
		||||
        hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
        # add hidden states from the last decoder layer
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        next_cache = None
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            next_cache = past_key_values.to_legacy_cache()
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
 | 
			
		||||
        return BaseModelOutputWithPast(
 | 
			
		||||
            last_hidden_state=hidden_states,
 | 
			
		||||
            past_key_values=next_cache,
 | 
			
		||||
            hidden_states=all_hidden_states,
 | 
			
		||||
            attentions=all_attns
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RetNetForCausalLM(RetNetPreTrainedModel):
 | 
			
		||||
    _tied_weights_keys = ["lm_head.weight"]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.model = RetNetModel(config)
 | 
			
		||||
        self.vocab_size = config.vocab_size
 | 
			
		||||
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 | 
			
		||||
 | 
			
		||||
        # Initialize weights and apply final processing
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.model.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, value):
 | 
			
		||||
        self.model.embeddings = value
 | 
			
		||||
 | 
			
		||||
    def get_output_embeddings(self):
 | 
			
		||||
        return self.lm_head
 | 
			
		||||
 | 
			
		||||
    def set_output_embeddings(self, new_embeddings):
 | 
			
		||||
        self.lm_head = new_embeddings
 | 
			
		||||
 | 
			
		||||
    def set_decoder(self, decoder):
 | 
			
		||||
        self.model = decoder
 | 
			
		||||
 | 
			
		||||
    def get_decoder(self):
 | 
			
		||||
        return self.model
 | 
			
		||||
 | 
			
		||||
    def generate(self, *args, **kwargs):
 | 
			
		||||
        try:
 | 
			
		||||
            return super().generate(*args, **kwargs)
 | 
			
		||||
        except AttributeError as exception:
 | 
			
		||||
            # Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
 | 
			
		||||
            if 'past_key_values' in str(exception):
 | 
			
		||||
                raise AttributeError(
 | 
			
		||||
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
 | 
			
		||||
                    f"which is not supported for {self.__class__.__name__}. "
 | 
			
		||||
                    f"Try another generation strategy instead. "
 | 
			
		||||
                    f"For the available generation strategies, check this doc: "
 | 
			
		||||
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                raise exception
 | 
			
		||||
 | 
			
		||||
    def prepare_inputs_for_generation(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        past_key_values: Optional[torch.Tensor] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        # only last token for `inputs_ids` if the `past_key_values` is passed along.
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            if not isinstance(past_key_values, RecurrentCache):
 | 
			
		||||
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
 | 
			
		||||
            input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]
 | 
			
		||||
 | 
			
		||||
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 | 
			
		||||
        if inputs_embeds is not None and past_key_values is None:
 | 
			
		||||
            model_inputs = {'inputs_embeds': inputs_embeds}
 | 
			
		||||
        else:
 | 
			
		||||
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
 | 
			
		||||
            # recompiles graphs as the stride of the inputs is a guard.
 | 
			
		||||
            # Ref: https://github.com/huggingface/transformers/pull/29114
 | 
			
		||||
            # TODO: use `next_tokens` directly instead.
 | 
			
		||||
            model_inputs = {'input_ids': input_ids.contiguous()}
 | 
			
		||||
 | 
			
		||||
        model_inputs.update({
 | 
			
		||||
            'past_key_values': past_key_values,
 | 
			
		||||
            'use_cache': kwargs.get('use_cache'),
 | 
			
		||||
            'attention_mask': attention_mask,
 | 
			
		||||
        })
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
        labels: Optional[torch.LongTensor] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None,
 | 
			
		||||
    ) -> Union[Tuple, CausalLMOutputWithPast]:
 | 
			
		||||
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 | 
			
		||||
        output_hidden_states = (
 | 
			
		||||
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 | 
			
		||||
        )
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
 | 
			
		||||
        outputs = self.model(
 | 
			
		||||
            input_ids=input_ids,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            inputs_embeds=inputs_embeds,
 | 
			
		||||
            past_key_values=past_key_values,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            output_attentions=output_attentions,
 | 
			
		||||
            output_hidden_states=output_hidden_states,
 | 
			
		||||
            return_dict=return_dict
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        hidden_states = outputs[0]
 | 
			
		||||
        logits = self.lm_head(hidden_states)
 | 
			
		||||
 | 
			
		||||
        loss = None
 | 
			
		||||
        if labels is not None:
 | 
			
		||||
            if self.config.fuse_cross_entropy:
 | 
			
		||||
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
 | 
			
		||||
            else:
 | 
			
		||||
                loss_fct = nn.CrossEntropyLoss()
 | 
			
		||||
            # Enable model parallelism
 | 
			
		||||
            labels = labels.to(logits.device)
 | 
			
		||||
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
 | 
			
		||||
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
 | 
			
		||||
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            output = (logits,) + outputs[1:]
 | 
			
		||||
            return (loss,) + output if loss is not None else output
 | 
			
		||||
 | 
			
		||||
        return CausalLMOutputWithPast(
 | 
			
		||||
            loss=loss,
 | 
			
		||||
            logits=logits,
 | 
			
		||||
            past_key_values=outputs.past_key_values,
 | 
			
		||||
            hidden_states=outputs.hidden_states,
 | 
			
		||||
            attentions=outputs.attentions,
 | 
			
		||||
        )
 | 
			
		||||
							
								
								
									
13  finetune/lora/v6/fla/models/rwkv6/__init__.py  vendored  Normal file
@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config
from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model

AutoConfig.register(RWKV6Config.model_type, RWKV6Config)
AutoModel.register(RWKV6Config, RWKV6Model)
AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM)


__all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model']
							
								
								
									
66  finetune/lora/v6/fla/models/rwkv6/configuration_rwkv6.py  vendored  Normal file
@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class RWKV6Config(PretrainedConfig):

    model_type = 'rwkv6'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        expand_k: int = 0.5,
        expand_v: int = 1,
        hidden_ratio: Optional[int] = 3.5,
        intermediate_size: Optional[int] = None,
        use_glu: Optional[bool] = False,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        proj_low_rank_dim: int = 32,
        gate_low_rank_dim: int = 64,
        hidden_act: str = "sqrelu",
        max_position_embeddings: int = 2048,
        eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.use_glu = use_glu
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.proj_low_rank_dim = proj_low_rank_dim
        self.gate_low_rank_dim = gate_low_rank_dim
        self.attn_mode = attn_mode
        self.hidden_act = hidden_act
        self.eps = eps
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.fuse_norm = fuse_norm
        self.fuse_cross_entropy = fuse_cross_entropy

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
							
								
								
									
443  finetune/lora/v6/fla/models/rwkv6/modeling_rwkv6.py  vendored  Normal file
@ -0,0 +1,443 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.utils.checkpoint
 | 
			
		||||
from transformers.modeling_outputs import (BaseModelOutputWithPast,
 | 
			
		||||
                                           CausalLMOutputWithPast)
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
from fla.layers.rwkv6 import LerpLinear, RWKV6Attention
 | 
			
		||||
from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config
 | 
			
		||||
from fla.models.utils import RecurrentCache
 | 
			
		||||
from fla.modules import FusedCrossEntropyLoss, LayerNorm
 | 
			
		||||
from fla.modules.activations import ACT2FN, swiglu_linear
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RWKV6FeedForward(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'sqrelu',
        layer_idx: Optional[int] = None
    ) -> RWKV6FeedForward:
        super().__init__()

        self.hidden_size = hidden_size
        if hidden_ratio is None:
            hidden_ratio = 3.5
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio)
            intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = LerpLinear(hidden_size, intermediate_size)
        self.value = nn.Linear(intermediate_size, hidden_size)
        self.receptance = LerpLinear(hidden_size, hidden_size)
        self.act_fn = ACT2FN[hidden_act]

        self.layer_idx = layer_idx

    def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None) -> torch.Tensor:
        if state is not None:
            raise NotImplementedError("Past state is not yet supported in `RWKV6FeedForward`.")
        shifted = self.time_shift(x)
        if len(shifted.shape) == 2:
            shifted = shifted.unsqueeze(1)
        delta = shifted - x
        key = self.act_fn(self.key(x, delta))
        value = self.value(key)
        receptance = self.receptance(x, delta)
        return receptance.sigmoid() * value

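# Illustrative sketch (not part of the committed file): the `time_shift` module above implements
# the RWKV token shift. `nn.ZeroPad2d((0, 0, 1, -1))` pads one step at the front of the time axis
# and crops the last step, so position t sees the hidden state of position t-1 (zeros at t = 0).
import torch
import torch.nn as nn

x = torch.arange(1, 5, dtype=torch.float).view(1, 4, 1)  # [batch, time, channel]
time_shift = nn.ZeroPad2d((0, 0, 1, -1))
print(time_shift(x).squeeze(-1))  # tensor([[0., 1., 2., 3.]])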
class RWKV6GLU(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish',
        layer_idx: Optional[int] = None
    ) -> RWKV6GLU:
        super().__init__()

        self.hidden_size = hidden_size
        # the final number of params is `hidden_ratio * hidden_size^2`
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.layer_idx = layer_idx

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        y = self.gate_proj(x)
        gate, y = y.chunk(2, -1)
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)

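# Illustrative sketch (not part of the committed file): assuming `swiglu_linear` follows the usual
# SwiGLU contract (fused swish gate followed by the down projection), an unfused reference is:
import torch
import torch.nn.functional as F

def swiglu_linear_reference(gate: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias):
    # swish(gate) * y, then the down projection, matching the fused call used in RWKV6GLU.forward
    return F.linear(F.silu(gate) * y, weight, bias)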
class RWKV6Block(nn.Module):
    def __init__(self, config: RWKV6Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.attn_norm = LayerNorm(hidden_size=config.hidden_size, eps=config.eps)
        self.attn = RWKV6Attention(
            mode=config.attn_mode,
            hidden_size=config.hidden_size,
            expand_k=config.expand_k,
            expand_v=config.expand_v,
            num_heads=config.num_heads,
            proj_low_rank_dim=config.proj_low_rank_dim,
            gate_low_rank_dim=config.gate_low_rank_dim,
            eps=config.eps,
            fuse_norm=config.fuse_norm,
            layer_idx=layer_idx
        )
        self.ffn_norm = LayerNorm(hidden_size=config.hidden_size, eps=config.eps)
        self.ffn = (RWKV6GLU if config.use_glu else RWKV6FeedForward)(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            layer_idx=layer_idx
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions
        )
        hidden_states, residual = self.ffn_norm(hidden_states, residual, True)
        hidden_states = self.ffn(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs

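# Illustrative sketch (not part of the committed file): `self.ffn_norm(hidden_states, residual, True)`
# relies on the fused residual path of fla's LayerNorm. Assuming it follows the common
# fused-add-norm contract, the unfused equivalent is roughly:
def fused_add_norm_reference(norm, hidden_states, residual):
    residual = hidden_states + residual  # updated residual stream
    hidden_states = norm(residual)       # normalized input for the next sub-layer
    return hidden_states, residual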
class RWKV6PreTrainedModel(PreTrainedModel):

    config_class = RWKV6Config
    supports_gradient_checkpointing = True
    _no_split_modules = ['RWKV6Block']

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = True,
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Parameter):
            nn.init.normal_(module, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["o_proj.weight", "down_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)

class RWKV6Model(RWKV6PreTrainedModel):

    def __init__(self, config: RWKV6Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([RWKV6Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = LayerNorm(config.hidden_size, eps=config.eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`RWKV6Model` does not support `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache:
            if past_key_values is None:
                past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers]
            if not isinstance(past_key_values, RecurrentCache):
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = past_key_values.to_legacy_cache()
        if not return_dict:
            return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )

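# Illustrative sketch (not part of the committed file): a minimal forward pass through RWKV6Model
# with the recurrent cache, assuming `fla` and its Triton kernels are installed. Sizes are
# arbitrary example values.
import torch
from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config

config = RWKV6Config(vocab_size=1000, hidden_size=512, num_hidden_layers=2, num_heads=8)
model = RWKV6Model(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    out = model(input_ids, use_cache=True)
print(out.last_hidden_state.shape)  # torch.Size([1, 16, 512])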
class RWKV6ForCausalLM(RWKV6PreTrainedModel):
 | 
			
		||||
    _tied_weights_keys = ["lm_head.weight"]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.model = RWKV6Model(config)
 | 
			
		||||
        self.vocab_size = config.vocab_size
 | 
			
		||||
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 | 
			
		||||
 | 
			
		||||
        # Initialize weights and apply final processing
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.model.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, value):
 | 
			
		||||
        self.model.embeddings = value
 | 
			
		||||
 | 
			
		||||
    def get_output_embeddings(self):
 | 
			
		||||
        return self.lm_head
 | 
			
		||||
 | 
			
		||||
    def set_output_embeddings(self, new_embeddings):
 | 
			
		||||
        self.lm_head = new_embeddings
 | 
			
		||||
 | 
			
		||||
    def set_decoder(self, decoder):
 | 
			
		||||
        self.model = decoder
 | 
			
		||||
 | 
			
		||||
    def get_decoder(self):
 | 
			
		||||
        return self.model
 | 
			
		||||
 | 
			
		||||
    def generate(self, *args, **kwargs):
 | 
			
		||||
        try:
 | 
			
		||||
            return super().generate(*args, **kwargs)
 | 
			
		||||
        except AttributeError as exception:
 | 
			
		||||
            if 'past_key_values' in str(exception):
 | 
			
		||||
                raise AttributeError(
 | 
			
		||||
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
 | 
			
		||||
                    f"which is not supported for {self.__class__.__name__}. "
 | 
			
		||||
                    f"Try another generation strategy instead. "
 | 
			
		||||
                    f"For the available generation strategies, check this doc: "
 | 
			
		||||
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                raise exception
 | 
			
		||||
 | 
			
		||||
    def prepare_inputs_for_generation(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.Tensor] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        # only last token for `input_ids` if the `past_key_values` is passed along.
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            if not isinstance(past_key_values, RecurrentCache):
 | 
			
		||||
                past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1)
 | 
			
		||||
            input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:]
 | 
			
		||||
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 | 
			
		||||
        if inputs_embeds is not None and past_key_values is None:
 | 
			
		||||
            model_inputs = {'inputs_embeds': inputs_embeds}
 | 
			
		||||
        else:
 | 
			
		||||
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
 | 
			
		||||
            # recompiles graphs as the stride of the inputs is a guard.
 | 
			
		||||
            # Ref: https://github.com/huggingface/transformers/pull/29114
 | 
			
		||||
            # TODO: use `next_tokens` directly instead.
 | 
			
		||||
            model_inputs = {'input_ids': input_ids.contiguous()}
 | 
			
		||||
 | 
			
		||||
        model_inputs.update({
 | 
			
		||||
            'past_key_values': past_key_values,
 | 
			
		||||
            'use_cache': kwargs.get('use_cache'),
 | 
			
		||||
            'attention_mask': attention_mask,
 | 
			
		||||
        })
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[List[torch.Tensor]]] = None,
 | 
			
		||||
        labels: Optional[torch.LongTensor] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None,
 | 
			
		||||
    ) -> Union[Tuple, CausalLMOutputWithPast]:
 | 
			
		||||
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 | 
			
		||||
        output_hidden_states = (
 | 
			
		||||
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 | 
			
		||||
        )
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        outputs = self.model(
 | 
			
		||||
            input_ids=input_ids,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            inputs_embeds=inputs_embeds,
 | 
			
		||||
            past_key_values=past_key_values,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            output_attentions=output_attentions,
 | 
			
		||||
            output_hidden_states=output_hidden_states,
 | 
			
		||||
            return_dict=return_dict
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        hidden_states = outputs[0]
 | 
			
		||||
        logits = self.lm_head(hidden_states)
 | 
			
		||||
 | 
			
		||||
        loss = None
 | 
			
		||||
        if labels is not None:
 | 
			
		||||
            if self.config.fuse_cross_entropy:
 | 
			
		||||
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
 | 
			
		||||
            else:
 | 
			
		||||
                loss_fct = nn.CrossEntropyLoss()
 | 
			
		||||
            # Enable model parallelism
 | 
			
		||||
            labels = labels.to(logits.device)
 | 
			
		||||
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
 | 
			
		||||
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
 | 
			
		||||
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            output = (logits,) + outputs[1:]
 | 
			
		||||
            return (loss,) + output if loss is not None else output
 | 
			
		||||
 | 
			
		||||
        return CausalLMOutputWithPast(
 | 
			
		||||
            loss=loss,
 | 
			
		||||
            logits=logits,
 | 
			
		||||
            past_key_values=outputs.past_key_values,
 | 
			
		||||
            hidden_states=outputs.hidden_states,
 | 
			
		||||
            attentions=outputs.attentions,
 | 
			
		||||
        )
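# Illustrative sketch (not part of the committed file): the loss above shifts `labels` one step to
# the left and pads the last position with `ignore_index`, which realizes the next-token objective
# without slicing the logits. A standalone check of the shift:
import torch

labels = torch.tensor([[11, 12, 13, 14]])
ignore_index = -100  # assumed ignore index of the loss functions used above
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1)
print(shifted)  # tensor([[  12,   13,   14, -100]])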
14  finetune/lora/v6/fla/models/transformer/__init__.py  vendored  Normal file
@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.models.transformer.modeling_transformer import (
    TransformerForCausalLM, TransformerModel)

AutoConfig.register(TransformerConfig.model_type, TransformerConfig)
AutoModel.register(TransformerConfig, TransformerModel)
AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)


__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel']
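# Illustrative sketch (not part of the committed file): once the registrations above have run
# (i.e. this module has been imported), the vanilla Transformer baseline can be built through the
# Auto classes. Sizes are arbitrary example values.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.for_model('transformer', hidden_size=512, num_hidden_layers=2, num_heads=8)
model = AutoModelForCausalLM.from_config(config)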
							
								
								
									
61  finetune/lora/v6/fla/models/transformer/configuration_transformer.py  vendored  Normal file
@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-

from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class TransformerConfig(PretrainedConfig):

    model_type = 'transformer'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        initializer_range: float = 0.02,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        attention_bias: bool = False,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.fuse_cross_entropy = fuse_cross_entropy
        self.fuse_norm = fuse_norm

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
522  finetune/lora/v6/fla/models/transformer/modeling_transformer.py  vendored  Normal file
@ -0,0 +1,522 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import torch.utils.checkpoint
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
from transformers.activations import ACT2FN
 | 
			
		||||
from transformers.cache_utils import Cache, DynamicCache
 | 
			
		||||
from transformers.modeling_outputs import (BaseModelOutputWithPast,
 | 
			
		||||
                                           CausalLMOutputWithPast)
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
from fla.models.transformer.configuration_transformer import TransformerConfig
 | 
			
		||||
from fla.modules import FusedCrossEntropyLoss, RMSNorm, RotaryEmbedding
 | 
			
		||||
from fla.modules.activations import swiglu_linear
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from flash_attn import flash_attn_func, flash_attn_varlen_func
 | 
			
		||||
    from flash_attn.bert_padding import (index_first_axis, pad_input,
 | 
			
		||||
                                         unpad_input)
 | 
			
		||||
except ImportError:
 | 
			
		||||
    warnings.warn("Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`")
 | 
			
		||||
    flash_attn_func = None
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransformerAttention(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        config: TransformerConfig,
 | 
			
		||||
        layer_idx: Optional[int] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.config = config
 | 
			
		||||
        self.layer_idx = layer_idx
 | 
			
		||||
 | 
			
		||||
        self.num_heads = config.num_heads
 | 
			
		||||
        if config.num_kv_heads is None:
 | 
			
		||||
            self.num_kv_heads = self.num_heads
 | 
			
		||||
        else:
 | 
			
		||||
            self.num_kv_heads = config.num_kv_heads
 | 
			
		||||
        self.num_kv_groups = config.num_heads // self.num_kv_heads
 | 
			
		||||
        self.hidden_size = config.hidden_size
 | 
			
		||||
        self.head_dim = self.hidden_size // self.num_heads
 | 
			
		||||
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = config.max_position_embeddings
 | 
			
		||||
 | 
			
		||||
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
 | 
			
		||||
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
 | 
			
		||||
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
 | 
			
		||||
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
 | 
			
		||||
 | 
			
		||||
        self.rotary = RotaryEmbedding(self.head_dim)
 | 
			
		||||
 | 
			
		||||
        self.apply(self._initialize_weights)
 | 
			
		||||
 | 
			
		||||
    def _initialize_weights(self, module: nn.Module):
 | 
			
		||||
        if getattr(module, "_is_hf_initialized", False):
 | 
			
		||||
            return
 | 
			
		||||
        if isinstance(module, nn.Linear):
 | 
			
		||||
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                nn.init.zeros_(module.bias)
 | 
			
		||||
        module._is_hf_initialized = True
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        attention_mask: Optional[torch.LongTensor] = None,
 | 
			
		||||
        past_key_values: Optional[Cache] = None,
 | 
			
		||||
        output_attentions: bool = False,
 | 
			
		||||
        use_cache: bool = False,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
			
		||||
        batch_size, q_len, _ = hidden_states.size()
 | 
			
		||||
        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
 | 
			
		||||
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', h=self.num_kv_heads)
 | 
			
		||||
        v = rearrange(self.v_proj(hidden_states), 'b t (h d) -> b h t d', h=self.num_kv_heads)
 | 
			
		||||
 | 
			
		||||
        seqlen_offset = 0
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
 | 
			
		||||
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            # to account for the offsets of padding tokens
 | 
			
		||||
            seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
 | 
			
		||||
        q, k = self.rotary(q, k, seqlen_offset, self.max_position_embeddings)
 | 
			
		||||
 | 
			
		||||
        k = rearrange(k, 'b t h d -> b h t d')
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            k, v = past_key_values.update(k, v, self.layer_idx)
 | 
			
		||||
        k, v = rearrange(k, 'b h t d -> b t h d'), rearrange(v, 'b h t d -> b t h d')
 | 
			
		||||
        if self.num_kv_groups > 1:
 | 
			
		||||
            k = rearrange(k.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d')
 | 
			
		||||
            v = rearrange(v.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d')
 | 
			
		||||
 | 
			
		||||
        if flash_attn_func is None:
 | 
			
		||||
            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
 | 
			
		||||
 | 
			
		||||
        # Contains at least one padding token in the sequence
 | 
			
		||||
        if attention_mask is not None:
 | 
			
		||||
            q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
 | 
			
		||||
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
 | 
			
		||||
            max_seqlen_q, max_seqlen_k = max_seq_lens
 | 
			
		||||
            o = flash_attn_varlen_func(
 | 
			
		||||
                q, k, v,
 | 
			
		||||
                cu_seqlens_q=cu_seqlens_q,
 | 
			
		||||
                cu_seqlens_k=cu_seqlens_k,
 | 
			
		||||
                max_seqlen_q=max_seqlen_q,
 | 
			
		||||
                max_seqlen_k=max_seqlen_k,
 | 
			
		||||
                causal=True
 | 
			
		||||
            )
 | 
			
		||||
            o = pad_input(o, indices_q, batch_size, q_len)
 | 
			
		||||
        else:
 | 
			
		||||
            o = flash_attn_func(q, k, v, causal=True)
 | 
			
		||||
        o = o.reshape(batch_size, q_len, self.hidden_size)
 | 
			
		||||
        o = self.o_proj(o)
 | 
			
		||||
 | 
			
		||||
        if not output_attentions:
 | 
			
		||||
            attentions = None
 | 
			
		||||
 | 
			
		||||
        return o, attentions, past_key_values
 | 
			
		||||
 | 
			
		||||
    def _upad_input(self, q, k, v, attention_mask, q_len):
 | 
			
		||||
        seqlens = attention_mask.sum(-1, dtype=torch.int32)
 | 
			
		||||
        indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
 | 
			
		||||
        max_seqlen_k = seqlens.max().item()
 | 
			
		||||
        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
 | 
			
		||||
        batch_size, seq_len, num_key_value_heads, head_dim = k.shape
 | 
			
		||||
 | 
			
		||||
        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
 | 
			
		||||
        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
 | 
			
		||||
        if q_len == seq_len:
 | 
			
		||||
            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
 | 
			
		||||
            cu_seqlens_q = cu_seqlens_k
 | 
			
		||||
            max_seqlen_q = max_seqlen_k
 | 
			
		||||
            indices_q = indices_k
 | 
			
		||||
        elif q_len == 1:
 | 
			
		||||
            max_seqlen_q = 1
 | 
			
		||||
            # There is a memcpy here, that is very bad.
 | 
			
		||||
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
 | 
			
		||||
            indices_q = cu_seqlens_q[:-1]
 | 
			
		||||
            q = q.squeeze(1)
 | 
			
		||||
        else:
 | 
			
		||||
            # The -q_len: slice assumes left padding.
 | 
			
		||||
            attention_mask = attention_mask[:, -q_len:]
 | 
			
		||||
            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
 | 
			
		||||
 | 
			
		||||
        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
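# Illustrative sketch (not part of the committed file): when `num_kv_heads < num_heads`
# (grouped-query attention), the repeat in TransformerAttention.forward expands each key/value
# head across its group before calling flash attention. A standalone check with toy sizes:
import torch
from einops import rearrange

b, t, h_kv, d, groups = 1, 3, 2, 4, 2  # arbitrary example sizes
k = torch.randn(b, t, h_kv, d)
k_rep = rearrange(k.unsqueeze(-2).repeat(1, 1, 1, groups, 1), 'b t h g d -> b t (h g) d')
print(k_rep.shape)  # torch.Size([1, 3, 4, 4])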
class TransformerMLP(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        hidden_ratio: Optional[int] = None,
 | 
			
		||||
        intermediate_size: Optional[int] = None,
 | 
			
		||||
        hidden_act: str = 'swish'
 | 
			
		||||
    ) -> TransformerMLP:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        # the final number of params is `hidden_ratio * hidden_size^2`
 | 
			
		||||
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
 | 
			
		||||
        if hidden_ratio is None:
 | 
			
		||||
            hidden_ratio = 4
 | 
			
		||||
        if intermediate_size is None:
 | 
			
		||||
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
 | 
			
		||||
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
 | 
			
		||||
        self.hidden_ratio = hidden_ratio
 | 
			
		||||
        self.intermediate_size = intermediate_size
 | 
			
		||||
 | 
			
		||||
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
 | 
			
		||||
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 | 
			
		||||
        self.act_fn = ACT2FN[hidden_act]
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        y = self.gate_proj(x)
 | 
			
		||||
        gate, y = y.chunk(2, -1)
 | 
			
		||||
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransformerBlock(nn.Module):
 | 
			
		||||
    def __init__(self, config: TransformerConfig, layer_idx: int):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.hidden_size = config.hidden_size
 | 
			
		||||
 | 
			
		||||
        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
        self.attn = TransformerAttention(
 | 
			
		||||
            config=config,
 | 
			
		||||
            layer_idx=layer_idx
 | 
			
		||||
        )
 | 
			
		||||
        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
        self.mlp = TransformerMLP(
 | 
			
		||||
            hidden_size=config.hidden_size,
 | 
			
		||||
            hidden_ratio=config.hidden_ratio,
 | 
			
		||||
            intermediate_size=config.intermediate_size,
 | 
			
		||||
            hidden_act=config.hidden_act
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = False,
 | 
			
		||||
        use_cache: Optional[bool] = False,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 | 
			
		||||
 | 
			
		||||
        residual = hidden_states
 | 
			
		||||
        hidden_states = self.attn_norm(hidden_states)
 | 
			
		||||
        hidden_states, attentions, past_key_values = self.attn(
 | 
			
		||||
            hidden_states=hidden_states,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            past_key_values=past_key_values,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            output_attentions=output_attentions
 | 
			
		||||
        )
 | 
			
		||||
        hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
 | 
			
		||||
        hidden_states = self.mlp(hidden_states)
 | 
			
		||||
        hidden_states = residual + hidden_states
 | 
			
		||||
 | 
			
		||||
        outputs = (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            outputs += (attentions,)
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            outputs += (past_key_values,)
 | 
			
		||||
 | 
			
		||||
        return outputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransformerPreTrainedModel(PreTrainedModel):
 | 
			
		||||
 | 
			
		||||
    config_class = TransformerConfig
 | 
			
		||||
    supports_gradient_checkpointing = True
 | 
			
		||||
    _no_split_modules = ['TransformerBlock']
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *inputs, **kwargs):
 | 
			
		||||
        super().__init__(*inputs, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def _init_weights(
 | 
			
		||||
        self,
 | 
			
		||||
        module: nn.Module,
 | 
			
		||||
        rescale_prenorm_residual: bool = True,
 | 
			
		||||
        num_residuals_per_layer: int = 2,
 | 
			
		||||
    ):
 | 
			
		||||
        if isinstance(module, (nn.Linear, nn.Conv1d)):
 | 
			
		||||
            # Slightly different from the TF version which uses truncated_normal for initialization
 | 
			
		||||
            # cf https://github.com/pytorch/pytorch/pull/5617
 | 
			
		||||
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
 | 
			
		||||
            if module.bias is not None:
 | 
			
		||||
                nn.init.zeros_(module.bias)
 | 
			
		||||
        elif isinstance(module, nn.Embedding):
 | 
			
		||||
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
 | 
			
		||||
            if module.padding_idx is not None:
 | 
			
		||||
                module.weight.data[module.padding_idx].zero_()
 | 
			
		||||
 | 
			
		||||
        if rescale_prenorm_residual:
 | 
			
		||||
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
 | 
			
		||||
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
 | 
			
		||||
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
 | 
			
		||||
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
 | 
			
		||||
            #
 | 
			
		||||
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
 | 
			
		||||
            for name, p in module.named_parameters():
 | 
			
		||||
                if name in ["o_proj.weight", "down_proj.weight"]:
 | 
			
		||||
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
 | 
			
		||||
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
 | 
			
		||||
                    # We need to reinit p since this code could be called multiple times
 | 
			
		||||
                    # Having just p *= scale would repeatedly scale it down
 | 
			
		||||
                    with torch.no_grad():
 | 
			
		||||
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransformerModel(TransformerPreTrainedModel):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config: TransformerConfig):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.padding_idx = config.pad_token_id
 | 
			
		||||
        self.vocab_size = config.vocab_size
 | 
			
		||||
 | 
			
		||||
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 | 
			
		||||
        self.layers = nn.ModuleList([TransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
 | 
			
		||||
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
 | 
			
		||||
 | 
			
		||||
        self.gradient_checkpointing = False
 | 
			
		||||
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, value):
 | 
			
		||||
        self.embeddings = value
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None
 | 
			
		||||
    ) -> Union[Tuple, CausalLMOutputWithPast]:
 | 
			
		||||
        if output_attentions:
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                "`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`."
 | 
			
		||||
            )
 | 
			
		||||
            output_attentions = False
 | 
			
		||||
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 | 
			
		||||
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        # retrieve input_ids and inputs_embeds
 | 
			
		||||
        if input_ids is not None and inputs_embeds is not None:
 | 
			
		||||
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 | 
			
		||||
        elif input_ids is None and inputs_embeds is None:
 | 
			
		||||
            raise ValueError("You have to specify either input_ids or inputs_embeds")
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            use_legacy_cache = not isinstance(past_key_values, Cache)
 | 
			
		||||
            if use_legacy_cache:
 | 
			
		||||
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 | 
			
		||||
 | 
			
		||||
        if inputs_embeds is None:
 | 
			
		||||
            inputs_embeds = self.embeddings(input_ids)
 | 
			
		||||
 | 
			
		||||
        # embed positions
 | 
			
		||||
        hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
            if use_cache:
 | 
			
		||||
                logger.warning_once(
 | 
			
		||||
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 | 
			
		||||
                )
 | 
			
		||||
                use_cache = False
 | 
			
		||||
 | 
			
		||||
        all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
        all_attns = () if output_attentions else None
 | 
			
		||||
        next_decoder_cache = None
 | 
			
		||||
 | 
			
		||||
        for layer in self.layers:
 | 
			
		||||
            if output_hidden_states:
 | 
			
		||||
                all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
            if self.gradient_checkpointing and self.training:
 | 
			
		||||
                layer_outputs = self._gradient_checkpointing_func(
 | 
			
		||||
                    layer.__call__,
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask,
 | 
			
		||||
                    past_key_values,
 | 
			
		||||
                    output_attentions,
 | 
			
		||||
                    use_cache
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                layer_outputs = layer(
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask=attention_mask,
 | 
			
		||||
                    past_key_values=past_key_values,
 | 
			
		||||
                    output_attentions=output_attentions,
 | 
			
		||||
                    use_cache=use_cache
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            hidden_states = layer_outputs[0]
 | 
			
		||||
 | 
			
		||||
            if use_cache:
 | 
			
		||||
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 | 
			
		||||
 | 
			
		||||
            if output_attentions:
 | 
			
		||||
                all_attns += (layer_outputs[1],)
 | 
			
		||||
 | 
			
		||||
        hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
        # add hidden states from the last decoder layer
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        next_cache = None
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
 | 
			
		||||
 | 
			
		||||
        return BaseModelOutputWithPast(
 | 
			
		||||
            last_hidden_state=hidden_states,
 | 
			
		||||
            past_key_values=next_cache,
 | 
			
		||||
            hidden_states=all_hidden_states,
 | 
			
		||||
            attentions=all_attns
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransformerForCausalLM(TransformerPreTrainedModel):
 | 
			
		||||
    _tied_weights_keys = ["lm_head.weight"]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.model = TransformerModel(config)
 | 
			
		||||
        self.vocab_size = config.vocab_size
 | 
			
		||||
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 | 
			
		||||
 | 
			
		||||
        # Initialize weights and apply final processing
 | 
			
		||||
        self.post_init()
 | 
			
		||||
 | 
			
		||||
    def get_input_embeddings(self):
 | 
			
		||||
        return self.model.embeddings
 | 
			
		||||
 | 
			
		||||
    def set_input_embeddings(self, value):
 | 
			
		||||
        self.model.embeddings = value
 | 
			
		||||
 | 
			
		||||
    def get_output_embeddings(self):
 | 
			
		||||
        return self.lm_head
 | 
			
		||||
 | 
			
		||||
    def set_output_embeddings(self, new_embeddings):
 | 
			
		||||
        self.lm_head = new_embeddings
 | 
			
		||||
 | 
			
		||||
    def set_decoder(self, decoder):
 | 
			
		||||
        self.model = decoder
 | 
			
		||||
 | 
			
		||||
    def get_decoder(self):
 | 
			
		||||
        return self.model
 | 
			
		||||
 | 
			
		||||
    def prepare_inputs_for_generation(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        past_key_values: Optional[torch.Tensor] = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.Tensor] = None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        # only last token for `input_ids` if the `past_key_values` is passed along.
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            input_ids = input_ids[:, -1:]
 | 
			
		||||
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 | 
			
		||||
        if inputs_embeds is not None and past_key_values is None:
 | 
			
		||||
            model_inputs = {'inputs_embeds': inputs_embeds}
 | 
			
		||||
        else:
 | 
			
		||||
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
 | 
			
		||||
            # recompiles graphs as the stride of the inputs is a guard.
 | 
			
		||||
            # Ref: https://github.com/huggingface/transformers/pull/29114
 | 
			
		||||
            # TODO: use `next_tokens` directly instead.
 | 
			
		||||
            model_inputs = {'input_ids': input_ids.contiguous()}
 | 
			
		||||
 | 
			
		||||
        model_inputs.update({
 | 
			
		||||
            'past_key_values': past_key_values,
 | 
			
		||||
            'use_cache': kwargs.get('use_cache'),
 | 
			
		||||
            'attention_mask': attention_mask,
 | 
			
		||||
        })
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        labels: Optional[torch.LongTensor] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None,
 | 
			
		||||
    ) -> Union[Tuple, CausalLMOutputWithPast]:
 | 
			
		||||
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 | 
			
		||||
        output_hidden_states = (
 | 
			
		||||
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 | 
			
		||||
        )
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        outputs = self.model(
 | 
			
		||||
            input_ids=input_ids,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            past_key_values=past_key_values,
 | 
			
		||||
            inputs_embeds=inputs_embeds,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            output_attentions=output_attentions,
 | 
			
		||||
            output_hidden_states=output_hidden_states,
 | 
			
		||||
            return_dict=return_dict
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        hidden_states = outputs[0]
 | 
			
		||||
        logits = self.lm_head(hidden_states)
 | 
			
		||||
 | 
			
		||||
        loss = None
 | 
			
		||||
        if labels is not None:
 | 
			
		||||
            if self.config.fuse_cross_entropy:
 | 
			
		||||
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
 | 
			
		||||
            else:
 | 
			
		||||
                loss_fct = nn.CrossEntropyLoss()
 | 
			
		||||
            # Enable model parallelism
 | 
			
		||||
            labels = labels.to(logits.device)
 | 
			
		||||
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
 | 
			
		||||
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
 | 
			
		||||
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            output = (logits,) + outputs[1:]
 | 
			
		||||
            return (loss,) + output if loss is not None else output
 | 
			
		||||
 | 
			
		||||
        return CausalLMOutputWithPast(
 | 
			
		||||
            loss=loss,
 | 
			
		||||
            logits=logits,
 | 
			
		||||
            past_key_values=outputs.past_key_values,
 | 
			
		||||
            hidden_states=outputs.hidden_states,
 | 
			
		||||
            attentions=outputs.attentions,
 | 
			
		||||
        )
 | 
			
		||||
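# Illustrative sketch (not part of the diff above): the loss computation in `forward` shifts
# `labels` left by one and pads with `ignore_index`, so `logits[:, t]` is scored against the
# token at position t + 1 without slicing the logits. A minimal eager-mode check, assuming a
# plain nn.CrossEntropyLoss and small arbitrary sizes:
def _sketch_shifted_label_loss():
    import torch
    import torch.nn as nn
    torch.manual_seed(0)
    vocab, batch, seq_len = 11, 2, 5
    logits = torch.randn(batch, seq_len, vocab)
    labels = torch.randint(0, vocab, (batch, seq_len))
    ce = nn.CrossEntropyLoss()
    # variant used above: shift labels left and pad with ignore_index
    shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ce.ignore_index)), 1)
    loss_a = ce(logits.view(-1, vocab), shifted.view(-1))
    # textbook variant: drop the last logit and the first label
    loss_b = ce(logits[:, :-1].reshape(-1, vocab), labels[:, 1:].reshape(-1))
    assert torch.allclose(loss_a, loss_b)
    return loss_a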
							
								
								
									
107  finetune/lora/v6/fla/models/utils.py  vendored  Normal file
@ -0,0 +1,107 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
from typing import Any, Dict, List, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from transformers.cache_utils import Cache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RecurrentCache(Cache):
 | 
			
		||||
    """
 | 
			
		||||
    A cache used for storing hidden states produced by flash linear attention models.
 | 
			
		||||
 | 
			
		||||
    It stores the state of each layer as a tensor of shape `[batch_size, key_dim, value_dim]`.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        seen_tokens: int = 0
 | 
			
		||||
    ) -> RecurrentCache:
 | 
			
		||||
 | 
			
		||||
        self.states: List[torch.Tensor] = []
 | 
			
		||||
        self._seen_tokens = seen_tokens  # Used in `generate` to keep tally of how many tokens the cache has seen
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, layer_idx: int) -> torch.Tensor:
 | 
			
		||||
        if layer_idx < len(self):
 | 
			
		||||
            return self.states[layer_idx]
 | 
			
		||||
        else:
 | 
			
		||||
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
 | 
			
		||||
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        for state in self.states:
 | 
			
		||||
            yield state
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return len(self.states)
 | 
			
		||||
 | 
			
		||||
    def update(
 | 
			
		||||
        self,
 | 
			
		||||
        state: Tuple[torch.Tensor],
 | 
			
		||||
        layer_idx: int,
 | 
			
		||||
        offset: Optional[int] = 1,
 | 
			
		||||
        cache_kwargs: Optional[Dict[str, Any]] = None,
 | 
			
		||||
    ) -> Tuple[torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Updates the cache with the new `state` for the layer `layer_idx`.
 | 
			
		||||
 | 
			
		||||
        Parameters:
 | 
			
		||||
            state (`Tuple[torch.Tensor]`):
 | 
			
		||||
                The new state to cache.
 | 
			
		||||
            layer_idx (`int`):
 | 
			
		||||
                The index of the layer to cache the states for.
 | 
			
		||||
            offset (`int`):
 | 
			
		||||
                The number of tokens fed in the current step (used to advance the seen-token count).
 | 
			
		||||
            cache_kwargs (`Dict[str, Any]`, `optional`):
 | 
			
		||||
                Additional arguments for the cache subclass.
 | 
			
		||||
 | 
			
		||||
        Return:
 | 
			
		||||
            The updated state.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        if isinstance(state, torch.Tensor):
 | 
			
		||||
            state = (state,)
 | 
			
		||||
        if len(self.states) <= layer_idx:
 | 
			
		||||
            self.states.append(state)
 | 
			
		||||
        else:
 | 
			
		||||
            for i, s in enumerate(state):
 | 
			
		||||
                self.states[layer_idx][i].copy_(s)
 | 
			
		||||
            # update the number of seen tokens once we reach the last layer
 | 
			
		||||
            if layer_idx == len(self) - 1:
 | 
			
		||||
                self._seen_tokens += offset
 | 
			
		||||
 | 
			
		||||
        return state
 | 
			
		||||
 | 
			
		||||
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
 | 
			
		||||
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
 | 
			
		||||
        if len(self.states) <= layer_idx:
 | 
			
		||||
            return 0
 | 
			
		||||
        return self._seen_tokens
 | 
			
		||||
 | 
			
		||||
    def get_max_length(self) -> Optional[int]:
 | 
			
		||||
        """Returns the maximum sequence length of the cached states. RecurrentCache does not have a maximum length."""
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def reorder_cache(self, beam_idx: torch.LongTensor):
 | 
			
		||||
        """Reorders the cache for beam search, given the selected beam indices."""
 | 
			
		||||
        for layer_idx in range(len(self.states)):
 | 
			
		||||
            device = self.states[layer_idx].device
 | 
			
		||||
            self.states[layer_idx] = self.states[layer_idx].index_select(0, beam_idx.to(device))
 | 
			
		||||
 | 
			
		||||
    def to_legacy_cache(self) -> Tuple[torch.Tensor]:
 | 
			
		||||
        return tuple(self.states)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_legacy_cache(
 | 
			
		||||
        cls,
 | 
			
		||||
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
 | 
			
		||||
        seen_tokens: int = 0
 | 
			
		||||
    ) -> RecurrentCache:
 | 
			
		||||
        """Converts a cache in the legacy cache format into an equivalent `RecurrentCache`."""
 | 
			
		||||
 | 
			
		||||
        cache = cls(seen_tokens)
 | 
			
		||||
        if past_key_values is not None:
 | 
			
		||||
            for layer_idx in range(len(past_key_values)):
 | 
			
		||||
                cache.update(past_key_values[layer_idx], layer_idx)
 | 
			
		||||
        return cache
 | 
			
		||||
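# Illustrative usage sketch (not part of the diff above): `RecurrentCache` keeps one recurrent
# state (or tuple of states) per layer plus a running token counter; `update` appends on the
# first call for a layer and copies in place afterwards. A minimal sketch, assuming 2 layers,
# batch size 1 and an arbitrary [1, 4, 4] state shape:
def _sketch_recurrent_cache():
    import torch
    cache = RecurrentCache(seen_tokens=0)
    for layer_idx in range(2):                      # prefill: states are appended
        cache.update(torch.zeros(1, 4, 4), layer_idx, offset=8)
    assert len(cache) == 2
    for layer_idx in range(2):                      # decoding step: states are copied in place
        cache.update(torch.randn(1, 4, 4), layer_idx, offset=1)
    assert cache.get_seq_length() == 1              # the counter advances only on in-place updates
    legacy = cache.to_legacy_cache()                # tuple of per-layer states
    return RecurrentCache.from_legacy_cache(legacy, seen_tokens=cache.get_seq_length())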
							
								
								
									
20  finetune/lora/v6/fla/modules/__init__.py  vendored  Normal file
@ -0,0 +1,20 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from fla.modules.convolution import (ImplicitLongConvolution, LongConvolution,
 | 
			
		||||
                                     ShortConvolution)
 | 
			
		||||
from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss
 | 
			
		||||
from fla.modules.fused_norm_gate import (FusedLayerNormSwishGate,
 | 
			
		||||
                                         FusedLayerNormSwishGateLinear,
 | 
			
		||||
                                         FusedRMSNormSwishGate,
 | 
			
		||||
                                         FusedRMSNormSwishGateLinear)
 | 
			
		||||
from fla.modules.layernorm import (LayerNorm, LayerNormLinear, RMSNorm,
 | 
			
		||||
                                   RMSNormLinear)
 | 
			
		||||
from fla.modules.rotary import RotaryEmbedding
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
 | 
			
		||||
    'FusedCrossEntropyLoss',
 | 
			
		||||
    'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
 | 
			
		||||
    'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
 | 
			
		||||
    'RotaryEmbedding'
 | 
			
		||||
]
 | 
			
		||||
							
								
								
									
394  finetune/lora/v6/fla/modules/activations.py  vendored  Normal file
@ -0,0 +1,394 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# Copyright (c) 2023-2024, Tri Dao, Yu Zhang, Songlin Yang.
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
 | 
			
		||||
from fla.utils import contiguous
 | 
			
		||||
 | 
			
		||||
sigmoid_fwd_codestring = """
 | 
			
		||||
template <typename T> T sigmoid_fwd(T x) {
 | 
			
		||||
    return 1.0f / (1.0f + ::exp(-float(x)));
 | 
			
		||||
}
 | 
			
		||||
"""
 | 
			
		||||
sigmoid_bwd_codestring = """
 | 
			
		||||
template <typename T> T sigmoid_bwd(T x, T g) {
 | 
			
		||||
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
 | 
			
		||||
    return float(g) * x_sigmoid * (1.0f - x_sigmoid);
 | 
			
		||||
}
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
sigmoid_fwd = torch.cuda.jiterator._create_jit_fn(sigmoid_fwd_codestring)
 | 
			
		||||
sigmoid_bwd = torch.cuda.jiterator._create_jit_fn(sigmoid_bwd_codestring)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SigmoidFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def forward(ctx, x):
 | 
			
		||||
        ctx.save_for_backward(x)
 | 
			
		||||
        return sigmoid_fwd(x)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def backward(ctx, dout):
 | 
			
		||||
        x, = ctx.saved_tensors
 | 
			
		||||
        return sigmoid_bwd(x, dout)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
sigmoid = SigmoidFunction.apply
 | 
			
		||||
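# Illustrative usage sketch (not part of the diff above): `torch.cuda.jiterator._create_jit_fn`
# compiles the C++ snippets above into elementwise CUDA kernels, so `sigmoid` only accepts CUDA
# tensors. A minimal sketch, assuming a GPU is available:
def _sketch_jiterator_sigmoid():
    import torch
    x = torch.randn(8, device='cuda', requires_grad=True)
    y = sigmoid(x)                        # SigmoidFunction.apply -> sigmoid_fwd kernel
    y.sum().backward()                    # gradient comes from the sigmoid_bwd kernel
    assert torch.allclose(y, torch.sigmoid(x), atol=1e-6)
    return x.grad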
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({'BT': 16}, num_warps=2),
 | 
			
		||||
        triton.Config({'BT': 16}, num_warps=4),
 | 
			
		||||
        triton.Config({'BT': 16}, num_warps=8),
 | 
			
		||||
        triton.Config({'BT': 32}, num_warps=2),
 | 
			
		||||
        triton.Config({'BT': 32}, num_warps=4),
 | 
			
		||||
        triton.Config({'BT': 32}, num_warps=8),
 | 
			
		||||
        triton.Config({'BT': 64}, num_warps=2),
 | 
			
		||||
        triton.Config({'BT': 64}, num_warps=4),
 | 
			
		||||
        triton.Config({'BT': 64}, num_warps=8),
 | 
			
		||||
        triton.Config({'BT': 128}, num_warps=2),
 | 
			
		||||
        triton.Config({'BT': 128}, num_warps=4),
 | 
			
		||||
        triton.Config({'BT': 128}, num_warps=8),
 | 
			
		||||
        triton.Config({'BT': 256}, num_warps=2),
 | 
			
		||||
        triton.Config({'BT': 256}, num_warps=4),
 | 
			
		||||
        triton.Config({'BT': 256}, num_warps=8)
 | 
			
		||||
    ],
 | 
			
		||||
    key=['D']
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def logsigmoid_fwd_kernel(
 | 
			
		||||
    x,
 | 
			
		||||
    y,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    D: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i = tl.program_id(0)
 | 
			
		||||
    o_i = i * BT + tl.arange(0, BT)
 | 
			
		||||
 | 
			
		||||
    p_x = x + o_i
 | 
			
		||||
    p_y = y + o_i
 | 
			
		||||
    mask = o_i < T
 | 
			
		||||
 | 
			
		||||
    # [D,]
 | 
			
		||||
    b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)
 | 
			
		||||
    b_m = tl.minimum(0., b_x)
 | 
			
		||||
    b_z = 1. + tl.exp(-tl.abs(b_x))
 | 
			
		||||
    b_y = b_m - tl.log(b_z)
 | 
			
		||||
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), mask=mask)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({'BT': 16}, num_warps=2),
 | 
			
		||||
        triton.Config({'BT': 16}, num_warps=4),
 | 
			
		||||
        triton.Config({'BT': 16}, num_warps=8),
 | 
			
		||||
        triton.Config({'BT': 32}, num_warps=2),
 | 
			
		||||
        triton.Config({'BT': 32}, num_warps=4),
 | 
			
		||||
        triton.Config({'BT': 32}, num_warps=8),
 | 
			
		||||
        triton.Config({'BT': 64}, num_warps=2),
 | 
			
		||||
        triton.Config({'BT': 64}, num_warps=4),
 | 
			
		||||
        triton.Config({'BT': 64}, num_warps=8),
 | 
			
		||||
        triton.Config({'BT': 128}, num_warps=2),
 | 
			
		||||
        triton.Config({'BT': 128}, num_warps=4),
 | 
			
		||||
        triton.Config({'BT': 128}, num_warps=8),
 | 
			
		||||
        triton.Config({'BT': 256}, num_warps=2),
 | 
			
		||||
        triton.Config({'BT': 256}, num_warps=4),
 | 
			
		||||
        triton.Config({'BT': 256}, num_warps=8)
 | 
			
		||||
    ],
 | 
			
		||||
    key=['D']
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def logsigmoid_bwd_kernel(
 | 
			
		||||
    x,
 | 
			
		||||
    dx,
 | 
			
		||||
    dy,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    D: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i = tl.program_id(0)
 | 
			
		||||
    o_i = i * BT + tl.arange(0, BT)
 | 
			
		||||
 | 
			
		||||
    p_x = x + o_i
 | 
			
		||||
    p_dx = dx + o_i
 | 
			
		||||
    p_dy = dy + o_i
 | 
			
		||||
    mask = o_i < T
 | 
			
		||||
 | 
			
		||||
    # [D,]
 | 
			
		||||
    b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)
 | 
			
		||||
    b_dy = tl.load(p_dy, mask=mask, other=0.).to(tl.float32)
 | 
			
		||||
    b_dx = b_dy * (1. - tl.sigmoid(b_x))
 | 
			
		||||
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LogSigmoidFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def forward(ctx, x):
 | 
			
		||||
        T, D = x.numel(), x.shape[-1]
 | 
			
		||||
        y = torch.empty_like(x)
 | 
			
		||||
        logsigmoid_fwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, y, T=T, D=D)
 | 
			
		||||
        ctx.save_for_backward(x,)
 | 
			
		||||
        return y
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def backward(ctx, dy):
 | 
			
		||||
        x, = ctx.saved_tensors
 | 
			
		||||
        T, D = x.numel(), x.shape[-1]
 | 
			
		||||
        dx = torch.empty_like(x)
 | 
			
		||||
        logsigmoid_bwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, dx, dy, T=T, D=D)
 | 
			
		||||
        return dx
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logsigmoid = LogSigmoidFunction.apply
 | 
			
		||||
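# Illustrative sketch (not part of the diff above): the Triton kernels above rely on the standard
# numerically stable identity
#     log(sigmoid(x)) = min(0, x) - log(1 + exp(-|x|)),
# which avoids overflow for large |x|. A quick eager-mode check of the identity (no Triton needed):
def _sketch_stable_logsigmoid():
    import torch
    import torch.nn.functional as F
    x = torch.tensor([-50.0, -1.0, 0.0, 1.0, 50.0])
    stable = torch.minimum(torch.zeros_like(x), x) - torch.log1p(torch.exp(-x.abs()))
    assert torch.allclose(stable, F.logsigmoid(x))
    return stable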
 | 
			
		||||
swish_fwd_codestring = """
 | 
			
		||||
template <typename T> T swish_fwd(T x) {
 | 
			
		||||
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
 | 
			
		||||
    return float(x) * x_sigmoid;
 | 
			
		||||
}
 | 
			
		||||
"""
 | 
			
		||||
swish_bwd_codestring = """
 | 
			
		||||
template <typename T> T swish_bwd(T x, T g) {
 | 
			
		||||
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
 | 
			
		||||
    return float(g) * x_sigmoid * (1.0f - float(x) * x_sigmoid + float(x));
 | 
			
		||||
}
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
swish_fwd = torch.cuda.jiterator._create_jit_fn(swish_fwd_codestring)
 | 
			
		||||
swish_bwd = torch.cuda.jiterator._create_jit_fn(swish_bwd_codestring)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SwishFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def forward(ctx, x):
 | 
			
		||||
        ctx.save_for_backward(x)
 | 
			
		||||
        return swish_fwd(x)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def backward(ctx, dout):
 | 
			
		||||
        x, = ctx.saved_tensors
 | 
			
		||||
        return swish_bwd(x, dout)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
swish = SwishFunction.apply
 | 
			
		||||
 | 
			
		||||
# 1/sqrt(2*pi)-> 0.3989423
 | 
			
		||||
# 1/sqrt(2)   -> 0.70710678
 | 
			
		||||
# sqrt(2/pi)  -> 0.79788456
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# this function is the tanh approximation of gelu
 | 
			
		||||
# actual gelu is:
 | 
			
		||||
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
 | 
			
		||||
@torch.jit.script
 | 
			
		||||
def bias_gelu(y, bias):
 | 
			
		||||
    x = bias + y
 | 
			
		||||
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# gradient of the tanh approximation of gelu
 | 
			
		||||
# gradient of actual gelu is:
 | 
			
		||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
 | 
			
		||||
@torch.jit.script
 | 
			
		||||
def bias_gelu_bwd(g, y, bias):
 | 
			
		||||
    """Assume that y has shape (B, D) and bias has shape (D)"""
 | 
			
		||||
    x = bias + y
 | 
			
		||||
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
 | 
			
		||||
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
 | 
			
		||||
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
 | 
			
		||||
        1 + tanh_out
 | 
			
		||||
    )
 | 
			
		||||
    grad_y = ff * g
 | 
			
		||||
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GeLUFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    # bias is an optional argument
 | 
			
		||||
    def forward(ctx, input, bias):
 | 
			
		||||
        ctx.save_for_backward(input, bias)
 | 
			
		||||
        return bias_gelu(input, bias)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def backward(ctx, grad_output):
 | 
			
		||||
        input, bias = ctx.saved_tensors
 | 
			
		||||
        tmp = bias_gelu_bwd(grad_output, input, bias)
 | 
			
		||||
        return tmp, tmp
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
bias_gelu_impl = GeLUFunction.apply
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# this function is the tanh approximation of gelu
 | 
			
		||||
# actual gelu is:
 | 
			
		||||
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
 | 
			
		||||
@torch.jit.script
 | 
			
		||||
def gelu_fwd(x):
 | 
			
		||||
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# gradient of the tanh approximation of gelu
 | 
			
		||||
# gradient of actual gelu is:
 | 
			
		||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
 | 
			
		||||
@torch.jit.script
 | 
			
		||||
def gelu_bwd(g, x):
 | 
			
		||||
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
 | 
			
		||||
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
 | 
			
		||||
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
 | 
			
		||||
        1 + tanh_out
 | 
			
		||||
    )
 | 
			
		||||
    return (ff * g).to(dtype=x.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FastGeLUFunction(torch.autograd.Function):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    # bias is an optional argument
 | 
			
		||||
    def forward(ctx, input):
 | 
			
		||||
        ctx.save_for_backward(input)
 | 
			
		||||
        return gelu_fwd(input)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def backward(ctx, grad_output):
 | 
			
		||||
        (input,) = ctx.saved_tensors
 | 
			
		||||
        tmp = gelu_bwd(grad_output, input)
 | 
			
		||||
        return tmp
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
fast_gelu_impl = FastGeLUFunction.apply
 | 
			
		||||
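# Illustrative sketch (not part of the diff above): `gelu_fwd` is the usual tanh approximation
#     0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
# which PyTorch also exposes as F.gelu(x, approximate='tanh'). A quick eager-mode check:
def _sketch_gelu_tanh_approx():
    import torch
    import torch.nn.functional as F
    x = torch.linspace(-4, 4, 17)
    assert torch.allclose(gelu_fwd(x), F.gelu(x, approximate='tanh'), atol=1e-6)
    return gelu_fwd(x)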
 | 
			
		||||
 | 
			
		||||
@torch.jit.script
 | 
			
		||||
def relu_bwd(g, x):
 | 
			
		||||
    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.jit.script
 | 
			
		||||
def sqrelu_fwd(x):
 | 
			
		||||
    r = F.relu(x)
 | 
			
		||||
    return (r * r).to(dtype=x.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.jit.script
 | 
			
		||||
def sqrelu_bwd(g, x):
 | 
			
		||||
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SquaredReLUFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def forward(ctx, input):
 | 
			
		||||
        ctx.save_for_backward(input)
 | 
			
		||||
        return sqrelu_fwd(input)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def backward(ctx, grad_output):
 | 
			
		||||
        input, = ctx.saved_tensors
 | 
			
		||||
        return sqrelu_bwd(grad_output, input)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
sqrelu = SquaredReLUFunction.apply
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
swiglu_fwd_codestring = """
 | 
			
		||||
template <typename T> T swiglu_fwd(T x, T y) {
 | 
			
		||||
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
 | 
			
		||||
}
 | 
			
		||||
"""
 | 
			
		||||
swiglu_bwd_codestring = """
 | 
			
		||||
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
 | 
			
		||||
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
 | 
			
		||||
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
 | 
			
		||||
    dy = float(x) * x_sigmoid * float(g);
 | 
			
		||||
}
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
swiglu_bwd_with_output_codestring = """
 | 
			
		||||
template <typename T> T swiglu_bwd_with_output(T x, T y, T g, T& dx, T& dy, T& z) {
 | 
			
		||||
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
 | 
			
		||||
    float x_swish = float(x) * x_sigmoid;
 | 
			
		||||
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
 | 
			
		||||
    dy = x_swish * float(g);
 | 
			
		||||
    z = x_swish * float(y);
 | 
			
		||||
}
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
 | 
			
		||||
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
 | 
			
		||||
swiglu_bwd_with_output = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_with_output_codestring, num_outputs=3)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SwiGLUFunction(torch.autograd.Function):
 | 
			
		||||
    r"""
 | 
			
		||||
    Swish-Gated Linear Unit (SwiGLU) function.
 | 
			
		||||
 | 
			
		||||
    .. math::
 | 
			
		||||
        \text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def forward(ctx, x, y):
 | 
			
		||||
        ctx.save_for_backward(x, y)
 | 
			
		||||
        return swiglu_fwd(x, y)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def backward(ctx, dout):
 | 
			
		||||
        x, y = ctx.saved_tensors
 | 
			
		||||
        return swiglu_bwd(x, y, dout)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SwiGLULinearFunction(torch.autograd.Function):
 | 
			
		||||
    r"""
 | 
			
		||||
    Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation.
 | 
			
		||||
 | 
			
		||||
    .. math::
 | 
			
		||||
        \text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b
 | 
			
		||||
 | 
			
		||||
    This simple wrapper discards the intermediate result of SwiGLU(x, y) to save memory.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def forward(ctx, x, y, weight, bias):
 | 
			
		||||
        z = swiglu_fwd(x, y)
 | 
			
		||||
        out = F.linear(z.to(weight.dtype), weight, bias)
 | 
			
		||||
        # We don't store z, will be recomputed in the backward pass to save memory
 | 
			
		||||
        ctx.save_for_backward(x, y, weight)
 | 
			
		||||
        ctx.linear_bias_is_none = bias is None
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def backward(ctx, dout, *args):
 | 
			
		||||
        x, y, weight = ctx.saved_tensors
 | 
			
		||||
        dout = dout.reshape(-1, dout.shape[-1])
 | 
			
		||||
        dz = F.linear(dout, weight.t()).view_as(x)
 | 
			
		||||
        dx, dy, z = swiglu_bwd_with_output(x, y, dz)
 | 
			
		||||
        dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1]))
 | 
			
		||||
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
 | 
			
		||||
        return dx, dy, dlinear_weight, dlinear_bias
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
swiglu = SwiGLUFunction.apply
 | 
			
		||||
 | 
			
		||||
swiglu_linear = SwiGLULinearFunction.apply
 | 
			
		||||
 | 
			
		||||
ACT2FN = {
 | 
			
		||||
    'relu': F.relu,
 | 
			
		||||
    'sigmoid': sigmoid,
 | 
			
		||||
    'logsigmoid': logsigmoid,
 | 
			
		||||
    'silu': swish,
 | 
			
		||||
    'swish': swish,
 | 
			
		||||
    'sqrelu': sqrelu,
 | 
			
		||||
    'gelu': fast_gelu_impl,
 | 
			
		||||
    'bias_gelu': bias_gelu_impl,
 | 
			
		||||
}
 | 
			
		||||
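# Illustrative usage sketch (not part of the diff above): ACT2FN maps config strings to the
# activation callables defined in this file; `swiglu` computes silu(x) * y with the jiterator
# CUDA kernels above, so it needs CUDA tensors. A minimal sketch, assuming a GPU is available:
def _sketch_act2fn_and_swiglu():
    import torch
    import torch.nn.functional as F
    act = ACT2FN['swish']                              # == swish (SiLU)
    x = torch.randn(4, 8, device='cuda')
    y = torch.randn(4, 8, device='cuda')
    out = swiglu(x, y)                                 # fused silu(x) * y
    assert torch.allclose(out, F.silu(x) * y, atol=1e-6)
    return act(x), out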
							
								
								
									
336  finetune/lora/v6/fla/modules/convolution.py  vendored  Normal file
@ -0,0 +1,336 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
 | 
			
		||||
from fla.modules.activations import ACT2FN
 | 
			
		||||
from fla.utils import checkpoint
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 | 
			
		||||
except ImportError:
 | 
			
		||||
    causal_conv1d_fn = None
 | 
			
		||||
    causal_conv1d_update = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
 | 
			
		||||
    seqlen = u.shape[-1]
 | 
			
		||||
    fft_size = 2 * seqlen
 | 
			
		||||
    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
 | 
			
		||||
    if k_rev is not None:
 | 
			
		||||
        k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
 | 
			
		||||
        k_f = k_f + k_rev_f.conj()
 | 
			
		||||
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
 | 
			
		||||
 | 
			
		||||
    if len(u.shape) > 3:
 | 
			
		||||
        k_f = k_f.unsqueeze(1)
 | 
			
		||||
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
 | 
			
		||||
 | 
			
		||||
    out = y + u
 | 
			
		||||
    if gelu:
 | 
			
		||||
        out = F.gelu(out)
 | 
			
		||||
    if dropout_mask is not None:
 | 
			
		||||
        return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
 | 
			
		||||
    else:
 | 
			
		||||
        return out.to(dtype=u.dtype)
 | 
			
		||||
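# Illustrative sketch (not part of the diff above): with zero-padding to 2 * seqlen, the FFT
# product in `fft_conv` realises a causal linear convolution
#     y[b, d, i] = sum_{j <= i} k[d, j] * u[b, d, i - j]
# followed by a residual connection (out = y + u) and an optional GELU. A small O(L^2) reference
# check of that reading, assuming gelu=False and no dropout:
def _sketch_fft_conv_reference():
    import torch
    b, d, l = 2, 3, 16
    u = torch.randn(b, d, l)
    k = torch.randn(d, l)
    ref = torch.zeros_like(u)
    for i in range(l):
        for j in range(i + 1):
            ref[:, :, i] += k[:, j] * u[:, :, i - j]
    ref = ref + u                                     # residual connection
    out = fft_conv(u, k, dropout_mask=None, gelu=False)
    assert torch.allclose(out, ref, atol=1e-4)
    return out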
 | 
			
		||||
 | 
			
		||||
@checkpoint
 | 
			
		||||
def proj_then_conv1d(
 | 
			
		||||
    x: torch.Tensor,
 | 
			
		||||
    proj_weight: torch.Tensor,
 | 
			
		||||
    conv1d_weight: torch.Tensor,
 | 
			
		||||
    conv1d_bias: Optional[torch.Tensor] = None,
 | 
			
		||||
    cache: Optional[torch.Tensor] = None
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    # We do matmul and transpose BLH -> HBL at the same time
 | 
			
		||||
    x = rearrange(proj_weight @ rearrange(x, "b l d -> d (b l)"), "d (b l) -> b d l", l=x.shape[-2])
 | 
			
		||||
 | 
			
		||||
    if causal_conv1d_fn is None:
 | 
			
		||||
        raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
 | 
			
		||||
    if cache is None:
 | 
			
		||||
        x = causal_conv1d_fn(
 | 
			
		||||
            x=x,
 | 
			
		||||
            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
 | 
			
		||||
            bias=conv1d_bias,
 | 
			
		||||
            activation="silu",
 | 
			
		||||
        ).transpose(1, 2)
 | 
			
		||||
    else:
 | 
			
		||||
        assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
 | 
			
		||||
        x = x.squeeze(-1)
 | 
			
		||||
        x = causal_conv1d_update(
 | 
			
		||||
            x=x,
 | 
			
		||||
            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
 | 
			
		||||
            bias=conv1d_bias,
 | 
			
		||||
            cache=cache,
 | 
			
		||||
            activation="silu",
 | 
			
		||||
        )
 | 
			
		||||
    return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ShortConvolution(nn.Conv1d):
 | 
			
		||||
    """
 | 
			
		||||
    Simple wrapper around `nn.Conv1d` that accepts inputs with the channel dimension last.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        kernel_size: int,
 | 
			
		||||
        bias: bool = False,
 | 
			
		||||
        activation: Optional[str] = 'silu',
 | 
			
		||||
        use_causal_conv: Optional[bool] = True
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(in_channels=hidden_size,
 | 
			
		||||
                         out_channels=hidden_size,
 | 
			
		||||
                         kernel_size=kernel_size,
 | 
			
		||||
                         groups=hidden_size,
 | 
			
		||||
                         bias=bias,
 | 
			
		||||
                         padding=kernel_size - 1)
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.activation = None
 | 
			
		||||
        if activation is not None:
 | 
			
		||||
            assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
 | 
			
		||||
            self.activation = activation
 | 
			
		||||
 | 
			
		||||
        if use_causal_conv:
 | 
			
		||||
            if causal_conv1d_fn is None:
 | 
			
		||||
                warnings.warn("Please install `causal-conv1d` to use causal convolutions, setting `use_causal_conv` to False.")
 | 
			
		||||
                use_causal_conv = False
 | 
			
		||||
        self.use_causal_conv = use_causal_conv
 | 
			
		||||
 | 
			
		||||
    def extra_repr(self):
 | 
			
		||||
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
 | 
			
		||||
             ', stride={stride}')
 | 
			
		||||
        if self.padding != (0,) * len(self.padding):
 | 
			
		||||
            s += ', padding={padding}'
 | 
			
		||||
        if self.dilation != (1,) * len(self.dilation):
 | 
			
		||||
            s += ', dilation={dilation}'
 | 
			
		||||
        if self.output_padding != (0,) * len(self.output_padding):
 | 
			
		||||
            s += ', output_padding={output_padding}'
 | 
			
		||||
        if self.groups != 1:
 | 
			
		||||
            s += ', groups={groups}'
 | 
			
		||||
        if self.bias is None:
 | 
			
		||||
            s += ', bias=False'
 | 
			
		||||
        if self.padding_mode != 'zeros':
 | 
			
		||||
            s += ', padding_mode={padding_mode}'
 | 
			
		||||
        if self.activation is not None:
 | 
			
		||||
            s += ', activation={activation}'
 | 
			
		||||
        if not self.use_causal_conv:
 | 
			
		||||
            s += ', use_causal_conv={use_causal_conv}'
 | 
			
		||||
        return s.format(**self.__dict__)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        x: torch.Tensor,
 | 
			
		||||
        mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        cache: Optional[torch.Tensor] = None
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            x (`torch.Tensor`):
 | 
			
		||||
                Tensor of shape `[batch_size, seq_len, hidden_size]`
 | 
			
		||||
            mask (`Optional[torch.Tensor]`):
 | 
			
		||||
                Attention mask dealing with padded positions.
 | 
			
		||||
            cache (`Optional[torch.Tensor]`):
 | 
			
		||||
                Previous cache tensor of shape `[batch_size, hidden_size, kernel_size]`,
 | 
			
		||||
        Returns:
 | 
			
		||||
            Tensor of shape `[batch_size, seq_len, hidden_size]`. The `cache` (if provided) is updated inplace.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        if mask is not None:
 | 
			
		||||
            x = x.mul_(mask.unsqueeze(-1))
 | 
			
		||||
        if cache is not None and x.shape[1] == 1:
 | 
			
		||||
            return self.step(x, cache)
 | 
			
		||||
        x = rearrange(x, "b l d -> b d l")
 | 
			
		||||
        # Update state (B D W)
 | 
			
		||||
        if cache is not None:
 | 
			
		||||
            cache.copy_(F.pad(x, (self.kernel_size[0] - x.shape[-1], 0)))
 | 
			
		||||
        if self.use_causal_conv:
 | 
			
		||||
            x = causal_conv1d_fn(
 | 
			
		||||
                x=x,
 | 
			
		||||
                weight=rearrange(self.weight, "d 1 w -> d w"),
 | 
			
		||||
                bias=self.bias,
 | 
			
		||||
                activation=self.activation,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
 | 
			
		||||
            if self.activation is not None:
 | 
			
		||||
                x = ACT2FN[self.activation](x)
 | 
			
		||||
        return rearrange(x, "b d l -> b l d")
 | 
			
		||||
 | 
			
		||||
    def step(
 | 
			
		||||
        self,
 | 
			
		||||
        x: torch.Tensor,
 | 
			
		||||
        cache: torch.Tensor
 | 
			
		||||
    ):
 | 
			
		||||
        assert x.shape[1] == 1, "Only support decoding with 1 token at a time for now"
 | 
			
		||||
 | 
			
		||||
        x = x.squeeze(1)
 | 
			
		||||
        if self.use_causal_conv:
 | 
			
		||||
            x = causal_conv1d_update(
 | 
			
		||||
                x=x,
 | 
			
		||||
                conv_state=cache,
 | 
			
		||||
                weight=rearrange(self.weight, "d 1 w -> d w"),
 | 
			
		||||
                bias=self.bias,
 | 
			
		||||
                activation=self.activation,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            dtype = x.dtype
 | 
			
		||||
            cache.copy_(torch.roll(cache, shifts=-1, dims=-1))
 | 
			
		||||
            cache[:, :, -1] = x
 | 
			
		||||
            x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
 | 
			
		||||
            if self.bias is not None:
 | 
			
		||||
                x = x + self.bias
 | 
			
		||||
            if self.activation is not None:
 | 
			
		||||
                x = ACT2FN[self.activation](x).to(dtype=dtype)
 | 
			
		||||
        return x.unsqueeze(1)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def state_size(self) -> int:
 | 
			
		||||
        return self.hidden_size * self.kernel_size
 | 
			
		||||
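# Illustrative usage sketch (not part of the diff above): `ShortConvolution` is a depthwise
# causal nn.Conv1d that takes `[batch, seq_len, hidden]` input and can decode token by token
# against a `[batch, hidden, kernel_size]` cache. A CPU-only sketch using the fallback path
# (use_causal_conv=False and activation=None, so neither `causal-conv1d` nor CUDA is needed):
def _sketch_short_convolution():
    import torch
    conv = ShortConvolution(hidden_size=8, kernel_size=4, activation=None,
                            use_causal_conv=False)
    x = torch.randn(2, 16, 8)                       # [batch, seq_len, hidden]
    y = conv(x)                                     # full-sequence forward
    cache = x.new_zeros(2, 8, 4)                    # [batch, hidden, kernel_size]
    y_full = conv(x, cache=cache)                   # also fills the decoding cache in place
    step = conv(torch.randn(2, 1, 8), cache=cache)  # single-token decoding step
    return y.shape, y_full.shape, step.shape        # all [batch, *, hidden]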
 | 
			
		||||
 | 
			
		||||
class LongConvolution(nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
    LongConvolution applies a convolution operation on the input tensor using a fixed
 | 
			
		||||
    filter of length l_max.
 | 
			
		||||
    The filter is learned during training and is applied using FFT convolution.
 | 
			
		||||
    Args:
 | 
			
		||||
        hidden_size (int): The number of expected features in the input and output.
 | 
			
		||||
        l_max (int): The maximum sequence length.
 | 
			
		||||
    Returns:
 | 
			
		||||
        y: (b, l, d) tensor
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        l_max: int,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Initializes the LongConvolution module.
 | 
			
		||||
        Args:
 | 
			
		||||
            hidden_size (int): The number of expected features in the input and output.
 | 
			
		||||
            l_max (int): The maximum sequence length.
 | 
			
		||||
        """
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.filter = nn.Parameter(torch.randn(self.hidden_size, l_max), requires_grad=True)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor, *args, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        Applies the LongConvolution operation on the input tensor.
 | 
			
		||||
        Args:
 | 
			
		||||
            x: (b, l, d) tensor
 | 
			
		||||
        Returns:
 | 
			
		||||
            y: (b, l, d) tensor
 | 
			
		||||
        """
 | 
			
		||||
        x = x.transpose(1, 2)
 | 
			
		||||
        y = fft_conv(x, self.filter, dropout_mask=None, gelu=False)
 | 
			
		||||
        y = y.transpose(1, 2)
 | 
			
		||||
        return y.to(dtype=x.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PositionalEmbedding(nn.Module):
 | 
			
		||||
    def __init__(self, emb_dim: int, seq_len: int, **kwargs):
 | 
			
		||||
        """Complex exponential positional embeddings for implicit long convolution filters."""
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.seq_len = seq_len
 | 
			
		||||
        # The time embedding fed to the filters is normalized so that t_f = 1
 | 
			
		||||
        t = torch.linspace(0, 1, self.seq_len)[None, :, None]  # 1, L, 1
 | 
			
		||||
 | 
			
		||||
        if emb_dim > 1:
 | 
			
		||||
            bands = (emb_dim - 1) // 2
 | 
			
		||||
        # To compute the right embeddings we use the "proper" linspace
 | 
			
		||||
        t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
 | 
			
		||||
        w = 2 * math.pi * t_rescaled / seq_len  # 1, L, 1
 | 
			
		||||
 | 
			
		||||
        f = torch.linspace(1e-4, bands - 1, bands)[None, None]
 | 
			
		||||
        z = torch.exp(-1j * f * w)
 | 
			
		||||
        z = torch.cat([t, z.real, z.imag], dim=-1)
 | 
			
		||||
        self.z = nn.Parameter(z, requires_grad=False)
 | 
			
		||||
 | 
			
		||||
    def forward(self, L):
 | 
			
		||||
        return self.z[:, :L]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImplicitLongConvolution(nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
    Long convolution with implicit filter parameterized by an MLP.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        hidden_size (int):
 | 
			
		||||
            The number of expected features in the input and output.
 | 
			
		||||
        l_max (int):
 | 
			
		||||
            The maximum sequence length.
 | 
			
		||||
        d_emb (Optional[int]):
 | 
			
		||||
            The dimension of the positional embeddings. Must be odd and greater than or equal to 3 (time, sine, and cosine).
 | 
			
		||||
            Defaults to 3.
 | 
			
		||||
        d_hidden (Optional[int]):
 | 
			
		||||
            The number of features in the hidden layer of the MLP. Defaults to 16.
 | 
			
		||||
 | 
			
		||||
    Attributes:
 | 
			
		||||
        pos_emb (`PositionalEmbedding`): The positional embedding layer.
 | 
			
		||||
        mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        l_max: int,
 | 
			
		||||
        d_emb: int = 3,
 | 
			
		||||
        d_hidden: int = 16,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Long convolution with implicit filter parameterized by an MLP.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.d_emb = d_emb
 | 
			
		||||
 | 
			
		||||
        assert (
 | 
			
		||||
            d_emb % 2 != 0 and d_emb >= 3
 | 
			
		||||
        ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)"
 | 
			
		||||
        self.pos_emb = PositionalEmbedding(d_emb, l_max)
 | 
			
		||||
 | 
			
		||||
        # final linear layer
 | 
			
		||||
        self.mlp = nn.Sequential(
 | 
			
		||||
            nn.Linear(d_emb, d_hidden),
 | 
			
		||||
            torch.nn.ReLU(),
 | 
			
		||||
            nn.Linear(d_hidden, hidden_size),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def filter(self, seq_len: int, *args, **kwargs):
 | 
			
		||||
        k = self.mlp(self.pos_emb(seq_len))
 | 
			
		||||
 | 
			
		||||
        return k.transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor, *args, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            x: (b, l, d) tensor
 | 
			
		||||
        Returns:
 | 
			
		||||
            y: (b, l, d) tensor
 | 
			
		||||
        """
 | 
			
		||||
        x = x.transpose(1, 2)
 | 
			
		||||
        k = self.filter(x.shape[-1])
 | 
			
		||||
        y = fft_conv(x, k, dropout_mask=None, gelu=False)
 | 
			
		||||
 | 
			
		||||
        y = y.transpose(1, 2)
 | 
			
		||||
        return y.to(dtype=x.dtype)
 | 
			
		||||
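# Illustrative usage sketch (not part of the diff above): `LongConvolution` learns an explicit
# depthwise filter of length l_max and applies it with `fft_conv`; `ImplicitLongConvolution`
# instead generates the filter from positional embeddings through a small MLP, so its parameter
# count does not grow with l_max. A minimal CPU sketch with arbitrary sizes:
def _sketch_long_convolutions():
    import torch
    x = torch.randn(2, 64, 32)                      # [batch, seq_len, hidden]
    explicit = LongConvolution(hidden_size=32, l_max=64)
    implicit = ImplicitLongConvolution(hidden_size=32, l_max=64, d_emb=3, d_hidden=16)
    y1 = explicit(x)                                # (2, 64, 32)
    y2 = implicit(x)                                # (2, 64, 32)
    return y1.shape, y2.shape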
							
								
								
									
235  finetune/lora/v6/fla/modules/feature_map.py  vendored  Normal file
@ -0,0 +1,235 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from torch import nn
 | 
			
		||||
 | 
			
		||||
from fla.modules.layernorm import layer_norm_fn
 | 
			
		||||
from fla.utils import checkpoint
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@checkpoint
 | 
			
		||||
def flatten_diag_outer_product(x, y):
 | 
			
		||||
    z = torch.einsum("...i,...j->...ij", x, y)
 | 
			
		||||
    N = z.size(-1)
 | 
			
		||||
    indices = torch.triu_indices(N, N)
 | 
			
		||||
    return z[..., indices[0], indices[1]]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@checkpoint
 | 
			
		||||
def flatten_diag_outer_product_off1(x, y):
 | 
			
		||||
    z = torch.einsum("...i,...j->...ij", x, y)
 | 
			
		||||
    N = z.size(-1)
 | 
			
		||||
    indices = torch.triu_indices(N, N, 1)
 | 
			
		||||
    indices2 = torch.arange(0, N)
 | 
			
		||||
    return z[..., indices[0], indices[1]], z[..., indices2, indices2]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_power_of_2(n):
 | 
			
		||||
    return (n & (n - 1) == 0) and n != 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HedgehogFeatureMap(nn.Module):
 | 
			
		||||
 | 
			
		||||
    r"""
 | 
			
		||||
    Hedgehog feature map as introduced in
 | 
			
		||||
    `The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry <https://arxiv.org/abs/2402.04347>`_
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        head_dim: int
 | 
			
		||||
    ) -> HedgehogFeatureMap:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        # Trainable map
 | 
			
		||||
        self.layer = nn.Linear(head_dim, head_dim)
 | 
			
		||||
        self.init_weights_()
 | 
			
		||||
 | 
			
		||||
    def init_weights_(self):
 | 
			
		||||
        """Initialize trainable map as identity"""
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            identity = torch.eye(*self.layer.weight.shape[-2:], dtype=torch.float)
 | 
			
		||||
            self.layer.weight.copy_(identity.to(self.layer.weight))
 | 
			
		||||
        nn.init.zeros_(self.layer.bias)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        x = self.layer(x)  # shape b, h, l, d
 | 
			
		||||
        return torch.cat([2*x, -2*x], dim=-1).softmax(-1)
 | 
			
		||||
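# Illustrative sketch (not part of the diff above): the Hedgehog map doubles the feature
# dimension and applies a softmax over [2x, -2x], so every output is positive and each feature
# vector sums to 1 (a softmax-mimicking kernel). A quick check with arbitrary shapes:
def _sketch_hedgehog_properties():
    import torch
    fmap = HedgehogFeatureMap(head_dim=16)
    q = torch.randn(2, 4, 8, 16)                    # [batch, heads, seq_len, head_dim]
    phi = fmap(q)                                   # [batch, heads, seq_len, 2 * head_dim]
    assert phi.shape[-1] == 32 and (phi > 0).all()
    assert torch.allclose(phi.sum(-1), torch.ones_like(phi.sum(-1)))
    return phi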
 | 
			
		||||
 | 
			
		||||
class T2RFeatureMap(nn.Module):
 | 
			
		||||
 | 
			
		||||
    r"""
 | 
			
		||||
    Simple linear mapping feature map as in
 | 
			
		||||
    `Finetuning Pretrained Transformers into RNNs <https://arxiv.org/abs/2103.13076>`_
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        head_dim: int,
 | 
			
		||||
        dot_dim: int = None
 | 
			
		||||
    ) -> T2RFeatureMap:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        # Trainable map
 | 
			
		||||
        if dot_dim is None:
 | 
			
		||||
            dot_dim = head_dim
 | 
			
		||||
        self.layer = nn.Linear(head_dim, dot_dim)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        return self.layer(x).relu()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DPFPFeatureMap(nn.Module):
 | 
			
		||||
 | 
			
		||||
    r"""
 | 
			
		||||
    Deterministic Parameter-Free Projection (DPFP) feature map in
 | 
			
		||||
    `Linear Transformers Are Secretly Fast Weight Programmers <https://arxiv.org/abs/2102.11174>`_
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        head_dim: int,
 | 
			
		||||
        nu: int = 4
 | 
			
		||||
    ) -> DPFPFeatureMap:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.nu = nu
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        x = torch.cat([x.relu(), -x.relu()], dim=-1)
 | 
			
		||||
        x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu+1)], dim=-1)
 | 
			
		||||
        x_repeat = torch.cat([x] * self.nu, dim=-1)
 | 
			
		||||
        return x_repeat * x_rolled
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HadamardFeatureMap(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        head_dim: int
 | 
			
		||||
    ) -> HadamardFeatureMap:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        # Trainable map
 | 
			
		||||
        self.layer1 = nn.Linear(head_dim, head_dim)
 | 
			
		||||
        self.layer2 = nn.Linear(head_dim, head_dim)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        return self.layer1(x) * self.layer2(x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LearnableOuterProductFeatureMap(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        head_dim: int,
 | 
			
		||||
        feature_dim: int
 | 
			
		||||
    ) -> LearnableOuterProductFeatureMap:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        # Trainable map
 | 
			
		||||
        self.layer1 = nn.Linear(head_dim, feature_dim, bias=False)
 | 
			
		||||
        self.layer2 = nn.Linear(head_dim, feature_dim, bias=False)
 | 
			
		||||
        self.normalizer = feature_dim ** -0.5
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        return flatten_diag_outer_product(self.layer1(x), self.layer2(x))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LearnablePolySketchNonNegativeFeatureMap(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        head_dim: int,
 | 
			
		||||
        sketch_size: Optional[int] = None,
 | 
			
		||||
        degree: Optional[int] = 2
 | 
			
		||||
    ) -> LearnablePolySketchNonNegativeFeatureMap:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        assert is_power_of_2(degree) and degree >= 2, f"The degree {degree} must be a power of 2"
 | 
			
		||||
 | 
			
		||||
        self.head_dim = head_dim
 | 
			
		||||
        self.sketch_size = sketch_size if sketch_size is not None else head_dim
 | 
			
		||||
        self.degree = degree
 | 
			
		||||
 | 
			
		||||
        self.gamma = nn.Parameter(torch.ones(head_dim))
 | 
			
		||||
        self.beta = nn.Parameter(torch.zeros(head_dim))
 | 
			
		||||
        # NOTE: the sketch layers defined here are quite different from the original paper
 | 
			
		||||
        # currently we simply use linear layers without any non-linear activations
 | 
			
		||||
        self.sketches1 = nn.ModuleList([
 | 
			
		||||
            nn.Linear(head_dim, sketch_size, bias=False),
 | 
			
		||||
            *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
 | 
			
		||||
        ])
 | 
			
		||||
        self.sketches2 = nn.ModuleList([
 | 
			
		||||
            nn.Linear(head_dim, sketch_size, bias=False),
 | 
			
		||||
            *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
 | 
			
		||||
        ])
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        # Section 2.1
 | 
			
		||||
        x = layer_norm_fn(x, self.gamma, self.beta)
 | 
			
		||||
        # first map the input to sketch size with learnable parameters
 | 
			
		||||
        x = self.sketches1[0](x) * self.sketches2[0](x) * self.head_dim ** -0.5
 | 
			
		||||
        for i in range(1, int(math.log2(self.degree)) - 1):
 | 
			
		||||
            x = self.sketches1[i](x) * self.sketches2[i](x) * self.head_dim ** -0.5
 | 
			
		||||
        # do sketch mapping for log2(p) - 1 times in total
 | 
			
		||||
        # do p=2 mapping to ensure non-negativity
 | 
			
		||||
        return flatten_diag_outer_product(x, x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TaylorFeatureMap(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        head_dim: int
 | 
			
		||||
    ) -> TaylorFeatureMap:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.head_dim = head_dim
 | 
			
		||||
        self.r2 = math.sqrt(2)
 | 
			
		||||
        self.rd = math.sqrt(self.head_dim)
 | 
			
		||||
        self.rrd = math.sqrt(self.rd)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
 | 
			
		||||
        return torch.cat([torch.ones_like(x[..., 0:1]), x / self.rrd, x2_2 / (self.rd * self.r2), x2_1 / self.rd], dim=-1)
 | 
			
		||||
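# Illustrative sketch (not part of the diff above): with the scaling used above, the dot product
# of two Taylor features recovers the 2nd-order Taylor expansion of the softmax kernel,
#     phi(q) . phi(k) = 1 + (q.k)/sqrt(d) + (q.k)^2 / (2d)  ~  exp(q.k / sqrt(d)).
# A quick numerical check of that identity in float64:
def _sketch_taylor_feature_map():
    import math
    import torch
    d = 16
    fmap = TaylorFeatureMap(head_dim=d)
    q, k = torch.randn(d, dtype=torch.float64), torch.randn(d, dtype=torch.float64)
    lhs = (fmap(q) * fmap(k)).sum(-1)
    qk = (q * k).sum() / math.sqrt(d)
    rhs = 1 + qk + qk ** 2 / 2
    assert torch.allclose(lhs, rhs)
    return lhs, rhs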
 | 
			
		||||
 | 
			
		||||
class RebasedFeatureMap(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        head_dim: int,
 | 
			
		||||
        use_gamma: Optional[bool] = True,
 | 
			
		||||
        use_beta: Optional[bool] = True,
 | 
			
		||||
        normalize: Optional[bool] = True
 | 
			
		||||
    ) -> RebasedFeatureMap:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.head_dim = head_dim
 | 
			
		||||
        self.use_gamma = use_gamma
 | 
			
		||||
        self.use_beta = use_beta
 | 
			
		||||
        self.normalize = normalize
 | 
			
		||||
 | 
			
		||||
        self.gamma = None
 | 
			
		||||
        self.beta = None
 | 
			
		||||
        if use_gamma:
 | 
			
		||||
            self.gamma = nn.Parameter(torch.ones(head_dim))
 | 
			
		||||
        if use_beta:
 | 
			
		||||
            self.beta = nn.Parameter(torch.zeros(head_dim))
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor, flatten: Optional[bool] = True):
 | 
			
		||||
        if self.use_beta and self.use_gamma and self.normalize:
 | 
			
		||||
            x = layer_norm_fn(x, self.gamma, self.beta)
 | 
			
		||||
        elif self.normalize:
 | 
			
		||||
            x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta)
 | 
			
		||||
        elif self.use_gamma and self.use_beta:
 | 
			
		||||
            x = torch.addcmul(self.beta, x, self.gamma)
 | 
			
		||||
        elif self.use_gamma:
 | 
			
		||||
            x = x.mul(self.gamma)
 | 
			
		||||
        else:
 | 
			
		||||
            raise RuntimeError(f"Not supported combination of `use_gamma`, `use_beta` and `normalize`, "
 | 
			
		||||
                               f"which is currentlt set as (`{self.use_gamma}`, `{self.use_beta}`, `{self.normalize}`)")
 | 
			
		||||
        if not flatten:
 | 
			
		||||
            return x
 | 
			
		||||
        x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
 | 
			
		||||
        # rebased use learnable parameters to approximate any quadratic function
 | 
			
		||||
        return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1)
 | 
			
		||||
							
								
								
									
398  finetune/lora/v6/fla/modules/fused_cross_entropy.py  vendored  Normal file
@ -0,0 +1,398 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# Copyright (c) 2023, Tri Dao.
 | 
			
		||||
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
 | 
			
		||||
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
 | 
			
		||||
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
 | 
			
		||||
# version of PyTorch. The following 2 lines are for backward compatibility with
 | 
			
		||||
# older PyTorch.
 | 
			
		||||
if "all_gather_into_tensor" not in dir(torch.distributed):
 | 
			
		||||
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.heuristics(
 | 
			
		||||
    {
 | 
			
		||||
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
 | 
			
		||||
    }
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def cross_entropy_fwd_kernel(
 | 
			
		||||
    loss_ptr,  # data ptrs
 | 
			
		||||
    lse_ptr,
 | 
			
		||||
    z_loss_ptr,
 | 
			
		||||
    logits_ptr,
 | 
			
		||||
    labels_ptr,
 | 
			
		||||
    smoothing,
 | 
			
		||||
    logit_scale,
 | 
			
		||||
    lse_square_scale,
 | 
			
		||||
    ignored_index,
 | 
			
		||||
    total_classes,
 | 
			
		||||
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
 | 
			
		||||
    n_cols,  # shapes
 | 
			
		||||
    n_rows,
 | 
			
		||||
    logits_row_stride,  # strides
 | 
			
		||||
    BLOCK_SIZE: tl.constexpr,
 | 
			
		||||
    HAS_SMOOTHING: tl.constexpr,
 | 
			
		||||
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
 | 
			
		||||
    SPLIT: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    row_idx = tl.program_id(0)
 | 
			
		||||
    col_block_idx = tl.program_id(1)
 | 
			
		||||
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
 | 
			
		||||
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
 | 
			
		||||
    label_idx = tl.load(labels_ptr + row_idx)
 | 
			
		||||
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
 | 
			
		||||
        tl.float32
 | 
			
		||||
    ) * logit_scale
 | 
			
		||||
    max_logits = tl.max(logits, 0)
 | 
			
		||||
    if HAS_SMOOTHING:
 | 
			
		||||
        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
 | 
			
		||||
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
 | 
			
		||||
    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
 | 
			
		||||
    if label_idx == ignored_index:
 | 
			
		||||
        loss = 0.0
 | 
			
		||||
        z_loss = 0.0
 | 
			
		||||
    else:
 | 
			
		||||
        label_idx -= class_start_idx
 | 
			
		||||
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
 | 
			
		||||
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
 | 
			
		||||
        ):
 | 
			
		||||
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
 | 
			
		||||
            if HAS_SMOOTHING:
 | 
			
		||||
                loss = (
 | 
			
		||||
                    (lse if not SPLIT else 0.0)
 | 
			
		||||
                    - smoothing * sum_logits / total_classes
 | 
			
		||||
                    - (1 - smoothing) * logits_label
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                loss = (lse if not SPLIT else 0.0) - logits_label
 | 
			
		||||
        else:
 | 
			
		||||
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
 | 
			
		||||
            if HAS_SMOOTHING:
 | 
			
		||||
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
 | 
			
		||||
            else:
 | 
			
		||||
                loss = 0.0
 | 
			
		||||
        if not SPLIT:
 | 
			
		||||
            z_loss = lse_square_scale * lse * lse
 | 
			
		||||
            loss += z_loss
 | 
			
		||||
        else:
 | 
			
		||||
            z_loss = 0.0
 | 
			
		||||
    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
 | 
			
		||||
    if not SPLIT:
 | 
			
		||||
        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.heuristics(
 | 
			
		||||
    {
 | 
			
		||||
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
 | 
			
		||||
    }
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def cross_entropy_bwd_kernel(
 | 
			
		||||
    dlogits_ptr,  # data ptrs
 | 
			
		||||
    dloss_ptr,
 | 
			
		||||
    logits_ptr,
 | 
			
		||||
    lse_ptr,
 | 
			
		||||
    labels_ptr,
 | 
			
		||||
    smoothing,
 | 
			
		||||
    logit_scale,
 | 
			
		||||
    lse_square_scale,
 | 
			
		||||
    ignored_index,
 | 
			
		||||
    total_classes,
 | 
			
		||||
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
 | 
			
		||||
    n_cols,  # shapes
 | 
			
		||||
    logits_row_stride,  # strides
 | 
			
		||||
    dlogits_row_stride,
 | 
			
		||||
    dloss_row_stride,
 | 
			
		||||
    BLOCK_SIZE: tl.constexpr,
 | 
			
		||||
    HAS_SMOOTHING: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    row_idx = tl.program_id(0)
 | 
			
		||||
    col_block_idx = tl.program_id(1)
 | 
			
		||||
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
 | 
			
		||||
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
 | 
			
		||||
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
 | 
			
		||||
    label_idx = tl.load(labels_ptr + row_idx)
 | 
			
		||||
    if label_idx != ignored_index:
 | 
			
		||||
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
 | 
			
		||||
    else:
 | 
			
		||||
        dloss = 0.0
 | 
			
		||||
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
 | 
			
		||||
        tl.float32
 | 
			
		||||
    ) * logit_scale
 | 
			
		||||
    lse = tl.load(lse_ptr + row_idx)
 | 
			
		||||
    probs = tl.exp(logits - lse)
 | 
			
		||||
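    # Gradient of the z-loss term lse_square_scale * lse^2 w.r.t. the logits is
    # 2 * lse_square_scale * lse * softmax(logits), hence the extra term below.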
    probs += 2.0 * lse_square_scale * lse * probs
 | 
			
		||||
    label_idx -= class_start_idx
 | 
			
		||||
    if HAS_SMOOTHING:
 | 
			
		||||
        smooth_negative = smoothing / total_classes
 | 
			
		||||
        probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
 | 
			
		||||
    else:
 | 
			
		||||
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
 | 
			
		||||
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CrossEntropyLossFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def forward(
 | 
			
		||||
        ctx,
 | 
			
		||||
        logits,
 | 
			
		||||
        labels,
 | 
			
		||||
        smoothing=0.0,
 | 
			
		||||
        logit_scale=1.0,
 | 
			
		||||
        lse_square_scale=0.0,
 | 
			
		||||
        ignored_index=-100,
 | 
			
		||||
        inplace_backward=False,
 | 
			
		||||
        process_group=None,
 | 
			
		||||
    ):
 | 
			
		||||
        n_rows, n_cols = logits.shape
 | 
			
		||||
        assert labels.shape == (n_rows,)
 | 
			
		||||
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
 | 
			
		||||
        total_classes = world_size * n_cols
 | 
			
		||||
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
 | 
			
		||||
        class_start_idx = rank * n_cols
 | 
			
		||||
 | 
			
		||||
        if logits.stride(-1) != 1:
 | 
			
		||||
            logits = logits.contiguous()
 | 
			
		||||
        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
 | 
			
		||||
        MAX_BLOCK_SIZE = 64 * 1024
 | 
			
		||||
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
 | 
			
		||||
        num_warps = (
 | 
			
		||||
            4
 | 
			
		||||
            if BLOCK_SIZE < 2048
 | 
			
		||||
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
 | 
			
		||||
        )
 | 
			
		||||
        # We may split the lse computation across multiple blocks, then do a reduction
 | 
			
		||||
        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
 | 
			
		||||
        # where having just one thread block processing more than 64k elements is slow.
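        # (Combining per-block LSEs is exact: the logsumexp over all columns equals the
        # logsumexp of the per-block logsumexp values.)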
 | 
			
		||||
        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
 | 
			
		||||
        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
 | 
			
		||||
        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
 | 
			
		||||
        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
 | 
			
		||||
        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
 | 
			
		||||
        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
 | 
			
		||||
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
 | 
			
		||||
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
 | 
			
		||||
        with torch.cuda.device(logits.device.index):
 | 
			
		||||
            cross_entropy_fwd_kernel[(n_rows, n_splits)](
 | 
			
		||||
                losses,  # data ptrs
 | 
			
		||||
                lse,
 | 
			
		||||
                z_losses,
 | 
			
		||||
                logits,
 | 
			
		||||
                labels,
 | 
			
		||||
                smoothing,
 | 
			
		||||
                logit_scale,
 | 
			
		||||
                lse_square_scale,
 | 
			
		||||
                ignored_index,
 | 
			
		||||
                total_classes,
 | 
			
		||||
                class_start_idx,
 | 
			
		||||
                n_cols,  # shapes
 | 
			
		||||
                n_rows,
 | 
			
		||||
                logits.stride(0),  # strides
 | 
			
		||||
                BLOCK_SIZE=BLOCK_SIZE,  # constants
 | 
			
		||||
                num_warps=num_warps,
 | 
			
		||||
                SPLIT=split,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if split:
 | 
			
		||||
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
 | 
			
		||||
            # - predicted logit, and 0 otherwise.
 | 
			
		||||
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
 | 
			
		||||
            # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
 | 
			
		||||
            # For labels not in the vocab of this partition, losses contains
 | 
			
		||||
            # -0.1 * sum logit / total_classes.
 | 
			
		||||
            if n_splits > 1:
 | 
			
		||||
                lse = torch.logsumexp(lse, dim=0)
 | 
			
		||||
                losses = losses.sum(dim=0)
 | 
			
		||||
            if world_size > 1:
 | 
			
		||||
                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
 | 
			
		||||
                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
 | 
			
		||||
                handle_losses = torch.distributed.all_reduce(
 | 
			
		||||
                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
 | 
			
		||||
                )
 | 
			
		||||
                lse = torch.logsumexp(lse_allgather, dim=0)
 | 
			
		||||
                handle_losses.wait()
 | 
			
		||||
            # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
 | 
			
		||||
            # we just have to add the (global) lse.
 | 
			
		||||
            # If there's smoothing=0.1, the total losses are
 | 
			
		||||
            # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
 | 
			
		||||
            # Again, we just have to add the (global) lse.
 | 
			
		||||
            losses += lse
 | 
			
		||||
            if lse_square_scale != 0.0:
 | 
			
		||||
                z_losses = lse_square_scale * lse.square()
 | 
			
		||||
                z_losses.masked_fill_(labels == ignored_index, 0.0)
 | 
			
		||||
                losses += z_losses
 | 
			
		||||
            else:
 | 
			
		||||
                z_losses = torch.zeros_like(losses)
 | 
			
		||||
            losses.masked_fill_(labels == ignored_index, 0.0)
 | 
			
		||||
 | 
			
		||||
        ctx.save_for_backward(logits, lse, labels)
 | 
			
		||||
        ctx.mark_non_differentiable(z_losses)
 | 
			
		||||
        ctx.smoothing = smoothing
 | 
			
		||||
        ctx.logit_scale = logit_scale
 | 
			
		||||
        ctx.lse_square_scale = lse_square_scale
 | 
			
		||||
        ctx.ignored_index = ignored_index
 | 
			
		||||
        ctx.total_classes = total_classes
 | 
			
		||||
        ctx.class_start_idx = class_start_idx
 | 
			
		||||
        ctx.inplace_backward = inplace_backward
 | 
			
		||||
 | 
			
		||||
        return losses, z_losses
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def backward(ctx, grad_losses, grad_z_losses):
 | 
			
		||||
        del grad_z_losses  # z_losses are only for logging.
 | 
			
		||||
 | 
			
		||||
        logits, lse, labels = ctx.saved_tensors
 | 
			
		||||
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
 | 
			
		||||
        n_rows, n_cols = logits.shape
 | 
			
		||||
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
 | 
			
		||||
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
 | 
			
		||||
        def grid(META): return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
 | 
			
		||||
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
 | 
			
		||||
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
 | 
			
		||||
        with torch.cuda.device(logits.device.index):
 | 
			
		||||
            cross_entropy_bwd_kernel[grid](
 | 
			
		||||
                dlogits,  # data ptrs
 | 
			
		||||
                grad_losses,
 | 
			
		||||
                logits,
 | 
			
		||||
                lse,
 | 
			
		||||
                labels,
 | 
			
		||||
                ctx.smoothing,
 | 
			
		||||
                ctx.logit_scale,
 | 
			
		||||
                ctx.lse_square_scale,
 | 
			
		||||
                ctx.ignored_index,
 | 
			
		||||
                ctx.total_classes,
 | 
			
		||||
                ctx.class_start_idx,
 | 
			
		||||
                n_cols,  # shapes
 | 
			
		||||
                logits.stride(0),  # strides
 | 
			
		||||
                dlogits.stride(0),
 | 
			
		||||
                grad_losses.stride(0),
 | 
			
		||||
                BLOCK_SIZE=BLOCK_SIZE,  # constants
 | 
			
		||||
                num_warps=num_warps,
 | 
			
		||||
            )
 | 
			
		||||
        return dlogits, None, None, None, None, None, None, None, None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cross_entropy_loss(
 | 
			
		||||
    logits: torch.Tensor,
 | 
			
		||||
    labels: torch.Tensor,
 | 
			
		||||
    label_smoothing: float = 0.0,
 | 
			
		||||
    logit_scale: float = 1.0,
 | 
			
		||||
    lse_square_scale: float = 0.0,
 | 
			
		||||
    ignored_index=-100,
 | 
			
		||||
    inplace_backward: bool = False,
 | 
			
		||||
    process_group=None,
 | 
			
		||||
) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    """
 | 
			
		||||
    Arguments:
 | 
			
		||||
        logits: (batch, vocab_size)
 | 
			
		||||
        labels: (batch,)
 | 
			
		||||
        label_smoothing: float
 | 
			
		||||
        logit_scale: float. Multiply logits by this scale before calculating the loss.
 | 
			
		||||
        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
 | 
			
		||||
            This is also referred to as "z-loss".
 | 
			
		||||
        ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
 | 
			
		||||
        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
 | 
			
		||||
            This saves memory.
 | 
			
		||||
        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
 | 
			
		||||
            one part of the vocab. The loss will be aggregated across processes.
 | 
			
		||||
    Returns:
 | 
			
		||||
        losses: (batch,), float
 | 
			
		||||
        z_losses: (batch,), float
 | 
			
		||||
    """
 | 
			
		||||
    return CrossEntropyLossFunction.apply(
 | 
			
		||||
        logits,
 | 
			
		||||
        labels,
 | 
			
		||||
        label_smoothing,
 | 
			
		||||
        logit_scale,
 | 
			
		||||
        lse_square_scale,
 | 
			
		||||
        ignored_index,
 | 
			
		||||
        inplace_backward,
 | 
			
		||||
        process_group,
 | 
			
		||||
    )
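

# Minimal usage sketch (illustrative addition, not part of the vendored file). The batch
# size, vocabulary size and hyper-parameters below are assumptions; a CUDA device is required.
def _example_cross_entropy_loss():
    logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
    labels = torch.randint(0, 32000, (8,), device="cuda")
    losses, z_losses = cross_entropy_loss(
        logits, labels, label_smoothing=0.1, lse_square_scale=1e-4
    )
    # per-sample losses are returned; reduce before calling backward
    losses.mean().backward()
    return losses, z_losses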
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedCrossEntropyLoss(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        ignore_index=-100,
 | 
			
		||||
        reduction="mean",
 | 
			
		||||
        label_smoothing=0.0,
 | 
			
		||||
        logit_scale=1.0,
 | 
			
		||||
        lse_square_scale=0.0,
 | 
			
		||||
        inplace_backward=False,
 | 
			
		||||
        process_group=None,
 | 
			
		||||
        return_z_loss=False,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Arguments:
 | 
			
		||||
            ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
 | 
			
		||||
            label_smoothing: float
 | 
			
		||||
            lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
 | 
			
		||||
                This is also referred to as "z-loss".
 | 
			
		||||
            inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
 | 
			
		||||
                This saves memory.
 | 
			
		||||
            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
 | 
			
		||||
                one part of the vocab. The loss will be aggregated across processes.
 | 
			
		||||
            return_z_loss: bool. If True, we return the component of the loss contributed by
 | 
			
		||||
                the lse_square_scale value. This value is only for logging and does not support
 | 
			
		||||
                backprop.
 | 
			
		||||
        """
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        if reduction not in ["mean", "none", "sum"]:
 | 
			
		||||
            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
 | 
			
		||||
        self.ignore_index = ignore_index
 | 
			
		||||
        self.reduction = reduction
 | 
			
		||||
        self.label_smoothing = label_smoothing
 | 
			
		||||
        self.logit_scale = logit_scale
 | 
			
		||||
        self.lse_square_scale = lse_square_scale
 | 
			
		||||
        self.inplace_backward = inplace_backward
 | 
			
		||||
        self.process_group = process_group
 | 
			
		||||
        self.return_z_loss = return_z_loss
 | 
			
		||||
 | 
			
		||||
    def forward(self, input, target):
 | 
			
		||||
        """
 | 
			
		||||
        Arguments:
 | 
			
		||||
            input: (batch, vocab_size)
 | 
			
		||||
            target: (batch,)
 | 
			
		||||
        Returns:
 | 
			
		||||
            losses: (batch,) if reduction is 'none', else (1,), dtype float
 | 
			
		||||
            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
 | 
			
		||||
        """
 | 
			
		||||
        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
 | 
			
		||||
        loss, z_loss = cross_entropy_loss(
 | 
			
		||||
            input,
 | 
			
		||||
            target,
 | 
			
		||||
            label_smoothing=self.label_smoothing,
 | 
			
		||||
            logit_scale=self.logit_scale,
 | 
			
		||||
            lse_square_scale=self.lse_square_scale,
 | 
			
		||||
            ignored_index=self.ignore_index,
 | 
			
		||||
            inplace_backward=self.inplace_backward,
 | 
			
		||||
            process_group=self.process_group,
 | 
			
		||||
        )
 | 
			
		||||
        if self.reduction == "mean":
 | 
			
		||||
            loss = loss.sum() / (target != self.ignore_index).sum()
 | 
			
		||||
        elif self.reduction == "sum":
 | 
			
		||||
            loss = loss.sum()
 | 
			
		||||
        else:
 | 
			
		||||
            loss = loss
 | 
			
		||||
 | 
			
		||||
        if not self.return_z_loss:
 | 
			
		||||
            return loss
 | 
			
		||||
 | 
			
		||||
        if self.reduction == "mean":
 | 
			
		||||
            z_loss = z_loss.sum() / (target != self.ignore_index).sum()
 | 
			
		||||
        elif self.reduction == "sum":
 | 
			
		||||
            z_loss = z_loss.sum()
 | 
			
		||||
        else:
 | 
			
		||||
            z_loss = z_loss
 | 
			
		||||
 | 
			
		||||
        return loss, z_loss
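

# Minimal usage sketch of the module wrapper (illustrative addition, not part of the
# vendored file). Vocabulary size and batch shape are assumptions; a CUDA device is required.
def _example_fused_cross_entropy_loss():
    criterion = FusedCrossEntropyLoss(label_smoothing=0.1, reduction="mean")
    logits = torch.randn(16, 50304, device="cuda", requires_grad=True)
    target = torch.randint(0, 50304, (16,), device="cuda")
    loss = criterion(logits, target)
    loss.backward()
    return loss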
889  finetune/lora/v6/fla/modules/fused_norm_gate.py  vendored  Normal file
@@ -0,0 +1,889 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# Copyright (c) 2023, Tri Dao.
 | 
			
		||||
# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
 | 
			
		||||
# Implement residual + layer_norm / rms_norm.
 | 
			
		||||
 | 
			
		||||
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
 | 
			
		||||
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
 | 
			
		||||
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
 | 
			
		||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
 | 
			
		||||
from fla.utils import contiguous
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
 | 
			
		||||
    dtype = x.dtype
 | 
			
		||||
    if upcast:
 | 
			
		||||
        weight = weight.float()
 | 
			
		||||
        bias = bias.float() if bias is not None else None
 | 
			
		||||
    if upcast:
 | 
			
		||||
        x = x.float()
 | 
			
		||||
        residual = residual.float() if residual is not None else residual
 | 
			
		||||
    if residual is not None:
 | 
			
		||||
        x = (x + residual).to(x.dtype)
 | 
			
		||||
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
 | 
			
		||||
        dtype
 | 
			
		||||
    )
 | 
			
		||||
    return out if not prenorm else (out, x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
 | 
			
		||||
    dtype = x.dtype
 | 
			
		||||
    if upcast:
 | 
			
		||||
        weight = weight.float()
 | 
			
		||||
        bias = bias.float() if bias is not None else None
 | 
			
		||||
    if upcast:
 | 
			
		||||
        x = x.float()
 | 
			
		||||
        residual = residual.float() if residual is not None else residual
 | 
			
		||||
    if residual is not None:
 | 
			
		||||
        x = (x + residual).to(x.dtype)
 | 
			
		||||
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
 | 
			
		||||
    out = x * rstd * weight
    if bias is not None:
        out = out + bias
 | 
			
		||||
    out = out.to(dtype)
 | 
			
		||||
    return out if not prenorm else (out, x)
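

def rms_norm_swish_gate_ref(x, o, weight, bias, residual=None, eps=1e-6):
    # Unfused reference of what the fused kernels below compute (illustrative addition,
    # not part of the vendored file): y = rms_norm(x [+ residual]) * o * sigmoid(o).
    y = rms_norm_ref(x, weight, bias, residual=residual, eps=eps)
    return y * o * torch.sigmoid(o)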
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
 | 
			
		||||
)
 | 
			
		||||
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 | 
			
		||||
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
 | 
			
		||||
@triton.jit
 | 
			
		||||
def _layer_norm_fwd_1pass_kernel(
 | 
			
		||||
    X,  # pointer to the input
 | 
			
		||||
    O,  # pointer to the gate
 | 
			
		||||
    Y,  # pointer to the output
 | 
			
		||||
    W,  # pointer to the weights
 | 
			
		||||
    B,  # pointer to the biases
 | 
			
		||||
    RESIDUAL,  # pointer to the residual
 | 
			
		||||
    RESIDUAL_OUT,  # pointer to the residual
 | 
			
		||||
    Mean,  # pointer to the mean
 | 
			
		||||
    Rstd,  # pointer to the 1/std
 | 
			
		||||
    stride_x_row,  # how much to increase the pointer when moving by 1 row
 | 
			
		||||
    stride_y_row,
 | 
			
		||||
    stride_res_row,
 | 
			
		||||
    stride_res_out_row,
 | 
			
		||||
    N,  # number of columns in X
 | 
			
		||||
    eps,  # epsilon to avoid division by zero
 | 
			
		||||
    IS_RMS_NORM: tl.constexpr,
 | 
			
		||||
    BLOCK_N: tl.constexpr,
 | 
			
		||||
    HAS_RESIDUAL: tl.constexpr,
 | 
			
		||||
    STORE_RESIDUAL_OUT: tl.constexpr,
 | 
			
		||||
    HAS_WEIGHT: tl.constexpr,
 | 
			
		||||
    HAS_BIAS: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    # Map the program id to the row of X and Y it should compute.
 | 
			
		||||
    row = tl.program_id(0)
 | 
			
		||||
    X += row * stride_x_row
 | 
			
		||||
    Y += row * stride_y_row
 | 
			
		||||
    O += row * stride_x_row
 | 
			
		||||
    if HAS_RESIDUAL:
 | 
			
		||||
        RESIDUAL += row * stride_res_row
 | 
			
		||||
    if STORE_RESIDUAL_OUT:
 | 
			
		||||
        RESIDUAL_OUT += row * stride_res_out_row
 | 
			
		||||
    # Compute mean and variance
 | 
			
		||||
    cols = tl.arange(0, BLOCK_N)
 | 
			
		||||
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
 | 
			
		||||
    if HAS_RESIDUAL:
 | 
			
		||||
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
 | 
			
		||||
        x += residual
 | 
			
		||||
    if STORE_RESIDUAL_OUT:
 | 
			
		||||
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
 | 
			
		||||
    if not IS_RMS_NORM:
 | 
			
		||||
        mean = tl.sum(x, axis=0) / N
 | 
			
		||||
        tl.store(Mean + row, mean)
 | 
			
		||||
        xbar = tl.where(cols < N, x - mean, 0.0)
 | 
			
		||||
        var = tl.sum(xbar * xbar, axis=0) / N
 | 
			
		||||
    else:
 | 
			
		||||
        xbar = tl.where(cols < N, x, 0.0)
 | 
			
		||||
        var = tl.sum(xbar * xbar, axis=0) / N
 | 
			
		||||
    rstd = 1 / tl.sqrt(var + eps)
 | 
			
		||||
    tl.store(Rstd + row, rstd)
 | 
			
		||||
    # Normalize and apply linear transformation
 | 
			
		||||
    mask = cols < N
 | 
			
		||||
    if HAS_WEIGHT:
 | 
			
		||||
        w = tl.load(W + cols, mask=mask).to(tl.float32)
 | 
			
		||||
    if HAS_BIAS:
 | 
			
		||||
        b = tl.load(B + cols, mask=mask).to(tl.float32)
 | 
			
		||||
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
 | 
			
		||||
    y = x_hat * w if HAS_WEIGHT else x_hat
 | 
			
		||||
    if HAS_BIAS:
 | 
			
		||||
        y = y + b
 | 
			
		||||
 | 
			
		||||
    # Swish output gate
 | 
			
		||||
    o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)
 | 
			
		||||
    y = y * o * tl.sigmoid(o)
 | 
			
		||||
 | 
			
		||||
    # Write output
 | 
			
		||||
    tl.store(Y + cols, y, mask=mask)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _layer_norm_fwd(
 | 
			
		||||
    x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
 | 
			
		||||
):
 | 
			
		||||
    if residual is not None:
 | 
			
		||||
        residual_dtype = residual.dtype
 | 
			
		||||
    M, N = x.shape
 | 
			
		||||
    assert x.stride(-1) == 1
 | 
			
		||||
    if residual is not None:
 | 
			
		||||
        assert residual.stride(-1) == 1
 | 
			
		||||
        assert residual.shape == (M, N)
 | 
			
		||||
    if weight is not None:
 | 
			
		||||
        assert weight.shape == (N,)
 | 
			
		||||
        assert weight.stride(-1) == 1
 | 
			
		||||
    if bias is not None:
 | 
			
		||||
        assert bias.stride(-1) == 1
 | 
			
		||||
        assert bias.shape == (N,)
 | 
			
		||||
    # allocate output
 | 
			
		||||
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
 | 
			
		||||
    assert y.stride(-1) == 1
 | 
			
		||||
    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
 | 
			
		||||
        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
 | 
			
		||||
        assert residual_out.stride(-1) == 1
 | 
			
		||||
    else:
 | 
			
		||||
        residual_out = None
 | 
			
		||||
    mean = torch.empty((M,), dtype=torch.float32,
 | 
			
		||||
                       device="cuda") if not is_rms_norm else None
 | 
			
		||||
    rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
 | 
			
		||||
    # Less than 64KB per feature: enqueue fused kernel
 | 
			
		||||
    MAX_FUSED_SIZE = 65536 // x.element_size()
 | 
			
		||||
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
 | 
			
		||||
    if N > BLOCK_N:
 | 
			
		||||
        raise RuntimeError(
 | 
			
		||||
            "This layer norm doesn't support feature dim >= 64KB.")
 | 
			
		||||
    # heuristics for number of warps
 | 
			
		||||
    with torch.cuda.device(x.device.index):
 | 
			
		||||
        _layer_norm_fwd_1pass_kernel[(M,)](
 | 
			
		||||
            x,
 | 
			
		||||
            o,
 | 
			
		||||
            y,
 | 
			
		||||
            weight,
 | 
			
		||||
            bias,
 | 
			
		||||
            residual,
 | 
			
		||||
            residual_out,
 | 
			
		||||
            mean,
 | 
			
		||||
            rstd,
 | 
			
		||||
            x.stride(0),
 | 
			
		||||
            y.stride(0),
 | 
			
		||||
            residual.stride(0) if residual is not None else 0,
 | 
			
		||||
            residual_out.stride(0) if residual_out is not None else 0,
 | 
			
		||||
            N,
 | 
			
		||||
            eps,
 | 
			
		||||
            is_rms_norm,
 | 
			
		||||
            BLOCK_N,
 | 
			
		||||
            residual is not None,
 | 
			
		||||
            residual_out is not None,
 | 
			
		||||
            weight is not None,
 | 
			
		||||
            bias is not None,
 | 
			
		||||
        )
 | 
			
		||||
    # residual_out is None if residual is None and residual_dtype == input_dtype
 | 
			
		||||
    return y, mean, rstd, residual_out if residual_out is not None else x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
 | 
			
		||||
)
 | 
			
		||||
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 | 
			
		||||
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
 | 
			
		||||
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
 | 
			
		||||
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
 | 
			
		||||
@triton.jit
 | 
			
		||||
def _layer_norm_bwd_kernel(
 | 
			
		||||
    X,  # pointer to the input
 | 
			
		||||
    O,  # pointer to the gate
 | 
			
		||||
    W,  # pointer to the weights
 | 
			
		||||
    B,  # pointer to the biases
 | 
			
		||||
    Y,  # pointer to the output to be recomputed
 | 
			
		||||
    DY,  # pointer to the output gradient
 | 
			
		||||
    DX,  # pointer to the input gradient
 | 
			
		||||
    DO,  # pointer to the gate gradient
 | 
			
		||||
    DW,  # pointer to the partial sum of weights gradient
 | 
			
		||||
    DB,  # pointer to the partial sum of biases gradient
 | 
			
		||||
    DRESIDUAL,
 | 
			
		||||
    DRESIDUAL_IN,
 | 
			
		||||
    Mean,  # pointer to the mean
 | 
			
		||||
    Rstd,  # pointer to the 1/std
 | 
			
		||||
    stride_x_row,  # how much to increase the pointer when moving by 1 row
 | 
			
		||||
    stride_y_row,
 | 
			
		||||
    stride_dy_row,
 | 
			
		||||
    stride_dx_row,
 | 
			
		||||
    stride_dres_row,
 | 
			
		||||
    stride_dres_in_row,
 | 
			
		||||
    M,  # number of rows in X
 | 
			
		||||
    N,  # number of columns in X
 | 
			
		||||
    eps,  # epsilon to avoid division by zero
 | 
			
		||||
    rows_per_program,
 | 
			
		||||
    IS_RMS_NORM: tl.constexpr,
 | 
			
		||||
    BLOCK_N: tl.constexpr,
 | 
			
		||||
    HAS_DRESIDUAL: tl.constexpr,
 | 
			
		||||
    STORE_DRESIDUAL: tl.constexpr,
 | 
			
		||||
    HAS_WEIGHT: tl.constexpr,
 | 
			
		||||
    HAS_BIAS: tl.constexpr,
 | 
			
		||||
    RECOMPUTE_OUTPUT: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    # Map the program id to the elements of X, DX, and DY it should compute.
 | 
			
		||||
    row_block_id = tl.program_id(0)
 | 
			
		||||
    row_start = row_block_id * rows_per_program
 | 
			
		||||
    cols = tl.arange(0, BLOCK_N)
 | 
			
		||||
    mask = cols < N
 | 
			
		||||
    X += row_start * stride_x_row
 | 
			
		||||
    O += row_start * stride_x_row
 | 
			
		||||
    if HAS_DRESIDUAL:
 | 
			
		||||
        DRESIDUAL += row_start * stride_dres_row
 | 
			
		||||
    if STORE_DRESIDUAL:
 | 
			
		||||
        DRESIDUAL_IN += row_start * stride_dres_in_row
 | 
			
		||||
    DY += row_start * stride_dy_row
 | 
			
		||||
    DX += row_start * stride_dx_row
 | 
			
		||||
    DO += row_start * stride_dx_row
 | 
			
		||||
    if RECOMPUTE_OUTPUT:
 | 
			
		||||
        Y += row_start * stride_y_row
 | 
			
		||||
    if HAS_WEIGHT:
 | 
			
		||||
        w = tl.load(W + cols, mask=mask).to(tl.float32)
 | 
			
		||||
        dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
 | 
			
		||||
    if RECOMPUTE_OUTPUT and HAS_BIAS:
 | 
			
		||||
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
 | 
			
		||||
    if HAS_BIAS:
 | 
			
		||||
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
 | 
			
		||||
    row_end = min((row_block_id + 1) * rows_per_program, M)
 | 
			
		||||
    for row in range(row_start, row_end):
 | 
			
		||||
        # Load data to SRAM
 | 
			
		||||
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        o = tl.load(O + cols, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
        if not IS_RMS_NORM:
 | 
			
		||||
            mean = tl.load(Mean + row)
 | 
			
		||||
        rstd = tl.load(Rstd + row)
 | 
			
		||||
        # Compute dx
 | 
			
		||||
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
 | 
			
		||||
        xhat = tl.where(mask, xhat, 0.0)
 | 
			
		||||
 | 
			
		||||
        y = xhat * w if HAS_WEIGHT else xhat
 | 
			
		||||
        if HAS_BIAS:
 | 
			
		||||
            y = y + b
 | 
			
		||||
        if RECOMPUTE_OUTPUT:
 | 
			
		||||
            tl.store(Y + cols, y, mask=mask)
 | 
			
		||||
 | 
			
		||||
        sigmoid_o = tl.sigmoid(o)
 | 
			
		||||
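        # Gate gradient: d/do [o * sigmoid(o)] = sigmoid(o) + o * sigmoid(o) * (1 - sigmoid(o)),
        # so do = dy * y * that factor, with y the pre-gate normalized output.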
        do = dy * y * (sigmoid_o + o * sigmoid_o * (1 - sigmoid_o))
 | 
			
		||||
        dy = dy * o * sigmoid_o
 | 
			
		||||
        wdy = dy
 | 
			
		||||
        if HAS_WEIGHT:
 | 
			
		||||
            wdy = dy * w
 | 
			
		||||
            dw += dy * xhat
 | 
			
		||||
        if HAS_BIAS:
 | 
			
		||||
            db += dy
 | 
			
		||||
        if not IS_RMS_NORM:
 | 
			
		||||
            c1 = tl.sum(xhat * wdy, axis=0) / N
 | 
			
		||||
            c2 = tl.sum(wdy, axis=0) / N
 | 
			
		||||
            dx = (wdy - (xhat * c1 + c2)) * rstd
 | 
			
		||||
        else:
 | 
			
		||||
            c1 = tl.sum(xhat * wdy, axis=0) / N
 | 
			
		||||
            dx = (wdy - xhat * c1) * rstd
 | 
			
		||||
        if HAS_DRESIDUAL:
 | 
			
		||||
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
            dx += dres
 | 
			
		||||
        # Write dx
 | 
			
		||||
        if STORE_DRESIDUAL:
 | 
			
		||||
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
 | 
			
		||||
        tl.store(DX + cols, dx, mask=mask)
 | 
			
		||||
        tl.store(DO + cols, do, mask=mask)
 | 
			
		||||
 | 
			
		||||
        X += stride_x_row
 | 
			
		||||
        O += stride_x_row
 | 
			
		||||
        if HAS_DRESIDUAL:
 | 
			
		||||
            DRESIDUAL += stride_dres_row
 | 
			
		||||
        if STORE_DRESIDUAL:
 | 
			
		||||
            DRESIDUAL_IN += stride_dres_in_row
 | 
			
		||||
        if RECOMPUTE_OUTPUT:
 | 
			
		||||
            Y += stride_y_row
 | 
			
		||||
        DY += stride_dy_row
 | 
			
		||||
        DX += stride_dx_row
 | 
			
		||||
        DO += stride_dx_row
 | 
			
		||||
    if HAS_WEIGHT:
 | 
			
		||||
        tl.store(DW + row_block_id * N + cols, dw, mask=mask)
 | 
			
		||||
    if HAS_BIAS:
 | 
			
		||||
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _layer_norm_bwd(
 | 
			
		||||
    dy,
 | 
			
		||||
    x,
 | 
			
		||||
    o,
 | 
			
		||||
    weight,
 | 
			
		||||
    bias,
 | 
			
		||||
    eps,
 | 
			
		||||
    mean,
 | 
			
		||||
    rstd,
 | 
			
		||||
    dresidual=None,
 | 
			
		||||
    has_residual=False,
 | 
			
		||||
    is_rms_norm=False,
 | 
			
		||||
    x_dtype=None,
 | 
			
		||||
    recompute_output=False,
 | 
			
		||||
):
 | 
			
		||||
    M, N = x.shape
 | 
			
		||||
    assert x.stride(-1) == 1
 | 
			
		||||
    assert dy.stride(-1) == 1
 | 
			
		||||
    assert dy.shape == (M, N)
 | 
			
		||||
    if dresidual is not None:
 | 
			
		||||
        assert dresidual.stride(-1) == 1
 | 
			
		||||
        assert dresidual.shape == (M, N)
 | 
			
		||||
    if weight is not None:
 | 
			
		||||
        assert weight.shape == (N,)
 | 
			
		||||
        assert weight.stride(-1) == 1
 | 
			
		||||
    if bias is not None:
 | 
			
		||||
        assert bias.stride(-1) == 1
 | 
			
		||||
        assert bias.shape == (N,)
 | 
			
		||||
    # allocate output
 | 
			
		||||
    dx = (
 | 
			
		||||
        torch.empty_like(x)
 | 
			
		||||
        if x_dtype is None
 | 
			
		||||
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
 | 
			
		||||
    )
 | 
			
		||||
    do = (
 | 
			
		||||
        torch.empty_like(o)
 | 
			
		||||
        if x_dtype is None
 | 
			
		||||
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
 | 
			
		||||
    )
 | 
			
		||||
    dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
 | 
			
		||||
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
 | 
			
		||||
 | 
			
		||||
    # Less than 64KB per feature: enqueue fused kernel
 | 
			
		||||
    MAX_FUSED_SIZE = 65536 // x.element_size()
 | 
			
		||||
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
 | 
			
		||||
    if N > BLOCK_N:
 | 
			
		||||
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
 | 
			
		||||
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
 | 
			
		||||
    _dw = (
 | 
			
		||||
        torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
 | 
			
		||||
        if weight is not None
 | 
			
		||||
        else None
 | 
			
		||||
    )
 | 
			
		||||
    _db = (
 | 
			
		||||
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
 | 
			
		||||
        if bias is not None
 | 
			
		||||
        else None
 | 
			
		||||
    )
 | 
			
		||||
    rows_per_program = math.ceil(M / sm_count)
 | 
			
		||||
    grid = (sm_count,)
 | 
			
		||||
    with torch.cuda.device(x.device.index):
 | 
			
		||||
        _layer_norm_bwd_kernel[grid](
 | 
			
		||||
            x,
 | 
			
		||||
            o,
 | 
			
		||||
            weight,
 | 
			
		||||
            bias,
 | 
			
		||||
            y,
 | 
			
		||||
            dy,
 | 
			
		||||
            dx,
 | 
			
		||||
            do,
 | 
			
		||||
            _dw,
 | 
			
		||||
            _db,
 | 
			
		||||
            dresidual,
 | 
			
		||||
            dresidual_in,
 | 
			
		||||
            mean,
 | 
			
		||||
            rstd,
 | 
			
		||||
            x.stride(0),
 | 
			
		||||
            0 if not recompute_output else y.stride(0),
 | 
			
		||||
            dy.stride(0),
 | 
			
		||||
            dx.stride(0),
 | 
			
		||||
            dresidual.stride(0) if dresidual is not None else 0,
 | 
			
		||||
            dresidual_in.stride(0) if dresidual_in is not None else 0,
 | 
			
		||||
            M,
 | 
			
		||||
            N,
 | 
			
		||||
            eps,
 | 
			
		||||
            rows_per_program,
 | 
			
		||||
            is_rms_norm,
 | 
			
		||||
            BLOCK_N,
 | 
			
		||||
            dresidual is not None,
 | 
			
		||||
            dresidual_in is not None,
 | 
			
		||||
            weight is not None,
 | 
			
		||||
            bias is not None,
 | 
			
		||||
        )
 | 
			
		||||
    dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
 | 
			
		||||
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
 | 
			
		||||
    # Don't need to compute dresidual_in separately in this case
 | 
			
		||||
    if has_residual and dx.dtype == x.dtype:
 | 
			
		||||
        dresidual_in = dx
 | 
			
		||||
    return (dx, do, dw, db, dresidual_in) if not recompute_output else (dx, do, dw, db, dresidual_in, y)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LayerNormSwishGateFn(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def forward(
 | 
			
		||||
        ctx,
 | 
			
		||||
        x,
 | 
			
		||||
        o,
 | 
			
		||||
        weight,
 | 
			
		||||
        bias,
 | 
			
		||||
        residual=None,
 | 
			
		||||
        eps=1e-6,
 | 
			
		||||
        prenorm=False,
 | 
			
		||||
        residual_in_fp32=False,
 | 
			
		||||
        is_rms_norm=False,
 | 
			
		||||
    ):
 | 
			
		||||
        x_shape_og = x.shape
 | 
			
		||||
        o_shape_og = o.shape
 | 
			
		||||
        # reshape input data into 2D tensor
 | 
			
		||||
        x = x.reshape(-1, x.shape[-1])
 | 
			
		||||
        o = o.reshape(-1, o.shape[-1])
 | 
			
		||||
        if residual is not None:
 | 
			
		||||
            assert residual.shape == x_shape_og
 | 
			
		||||
            residual = residual.reshape(-1, residual.shape[-1])
 | 
			
		||||
        residual_dtype = (
 | 
			
		||||
            residual.dtype
 | 
			
		||||
            if residual is not None
 | 
			
		||||
            else (torch.float32 if residual_in_fp32 else None)
 | 
			
		||||
        )
 | 
			
		||||
        y, mean, rstd, residual_out = _layer_norm_fwd(
 | 
			
		||||
            x, o, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
 | 
			
		||||
        )
 | 
			
		||||
        ctx.save_for_backward(residual_out, o, weight, bias, mean, rstd)
 | 
			
		||||
        ctx.x_shape_og = x_shape_og
 | 
			
		||||
        ctx.o_shape_og = o_shape_og
 | 
			
		||||
        ctx.eps = eps
 | 
			
		||||
        ctx.is_rms_norm = is_rms_norm
 | 
			
		||||
        ctx.has_residual = residual is not None
 | 
			
		||||
        ctx.prenorm = prenorm
 | 
			
		||||
        ctx.x_dtype = x.dtype
 | 
			
		||||
        y = y.reshape(x_shape_og)
 | 
			
		||||
        return y if not prenorm else (y, residual_out.reshape(x_shape_og))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def backward(ctx, dy, *args):
 | 
			
		||||
        x, o, weight, bias, mean, rstd = ctx.saved_tensors
 | 
			
		||||
        dy = dy.reshape(-1, dy.shape[-1])
 | 
			
		||||
        assert dy.shape == x.shape
 | 
			
		||||
        if ctx.prenorm:
 | 
			
		||||
            dresidual = args[0]
 | 
			
		||||
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
 | 
			
		||||
            assert dresidual.shape == x.shape
 | 
			
		||||
        else:
 | 
			
		||||
            dresidual = None
 | 
			
		||||
        dx, do, dw, db, dresidual_in = _layer_norm_bwd(
 | 
			
		||||
            dy,
 | 
			
		||||
            x,
 | 
			
		||||
            o,
 | 
			
		||||
            weight,
 | 
			
		||||
            bias,
 | 
			
		||||
            ctx.eps,
 | 
			
		||||
            mean,
 | 
			
		||||
            rstd,
 | 
			
		||||
            dresidual,
 | 
			
		||||
            ctx.has_residual,
 | 
			
		||||
            ctx.is_rms_norm,
 | 
			
		||||
            x_dtype=ctx.x_dtype,
 | 
			
		||||
        )
 | 
			
		||||
        return (
 | 
			
		||||
            dx.reshape(ctx.x_shape_og),
 | 
			
		||||
            do.reshape(ctx.o_shape_og),
 | 
			
		||||
            dw,
 | 
			
		||||
            db,
 | 
			
		||||
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LayerNormSwishGateLinearFn(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def forward(
 | 
			
		||||
        ctx,
 | 
			
		||||
        x,
 | 
			
		||||
        o,
 | 
			
		||||
        norm_weight,
 | 
			
		||||
        norm_bias,
 | 
			
		||||
        linear_weight,
 | 
			
		||||
        linear_bias,
 | 
			
		||||
        residual=None,
 | 
			
		||||
        eps=1e-6,
 | 
			
		||||
        prenorm=False,
 | 
			
		||||
        residual_in_fp32=False,
 | 
			
		||||
        is_rms_norm=False,
 | 
			
		||||
    ):
 | 
			
		||||
        x_shape_og = x.shape
 | 
			
		||||
        o_shape_og = o.shape
 | 
			
		||||
        # reshape input data into 2D tensor
 | 
			
		||||
        x = x.reshape(-1, x.shape[-1])
 | 
			
		||||
        o = o.reshape(-1, o.shape[-1])
 | 
			
		||||
        if residual is not None:
 | 
			
		||||
            assert residual.shape == x_shape_og
 | 
			
		||||
            residual = residual.reshape(-1, residual.shape[-1])
 | 
			
		||||
        residual_dtype = (
 | 
			
		||||
            residual.dtype
 | 
			
		||||
            if residual is not None
 | 
			
		||||
            else (torch.float32 if residual_in_fp32 else None)
 | 
			
		||||
        )
 | 
			
		||||
        y, mean, rstd, residual_out = _layer_norm_fwd(
 | 
			
		||||
            x,
 | 
			
		||||
            o,
 | 
			
		||||
            norm_weight,
 | 
			
		||||
            norm_bias,
 | 
			
		||||
            eps,
 | 
			
		||||
            residual,
 | 
			
		||||
            residual_dtype=residual_dtype,
 | 
			
		||||
            is_rms_norm=is_rms_norm
 | 
			
		||||
        )
 | 
			
		||||
        y = y.reshape(x_shape_og)
 | 
			
		||||
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
 | 
			
		||||
        linear_weight = linear_weight.to(dtype)
 | 
			
		||||
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
 | 
			
		||||
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
 | 
			
		||||
        # We don't store y, will be recomputed in the backward pass to save memory
 | 
			
		||||
        ctx.save_for_backward(residual_out, o, norm_weight, norm_bias, linear_weight, mean, rstd)
 | 
			
		||||
        ctx.x_shape_og = x_shape_og
 | 
			
		||||
        ctx.o_shape_og = o_shape_og
 | 
			
		||||
        ctx.eps = eps
 | 
			
		||||
        ctx.is_rms_norm = is_rms_norm
 | 
			
		||||
        ctx.has_residual = residual is not None
 | 
			
		||||
        ctx.prenorm = prenorm
 | 
			
		||||
        ctx.x_dtype = x.dtype
 | 
			
		||||
        ctx.linear_bias_is_none = linear_bias is None
 | 
			
		||||
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def backward(ctx, dout, *args):
 | 
			
		||||
        x, o, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
 | 
			
		||||
        dout = dout.reshape(-1, dout.shape[-1])
 | 
			
		||||
        dy = F.linear(dout, linear_weight.t())
 | 
			
		||||
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
 | 
			
		||||
        assert dy.shape == x.shape
 | 
			
		||||
        if ctx.prenorm:
 | 
			
		||||
            dresidual = args[0]
 | 
			
		||||
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
 | 
			
		||||
            assert dresidual.shape == x.shape
 | 
			
		||||
        else:
 | 
			
		||||
            dresidual = None
 | 
			
		||||
        dx, do, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
 | 
			
		||||
            dy,
 | 
			
		||||
            x,
 | 
			
		||||
            o,
 | 
			
		||||
            norm_weight,
 | 
			
		||||
            norm_bias,
 | 
			
		||||
            ctx.eps,
 | 
			
		||||
            mean,
 | 
			
		||||
            rstd,
 | 
			
		||||
            dresidual=dresidual,
 | 
			
		||||
            has_residual=ctx.has_residual,
 | 
			
		||||
            is_rms_norm=ctx.is_rms_norm,
 | 
			
		||||
            x_dtype=ctx.x_dtype,
 | 
			
		||||
            recompute_output=True,
 | 
			
		||||
        )
 | 
			
		||||
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
 | 
			
		||||
        return (
 | 
			
		||||
            dx.reshape(ctx.x_shape_og),
 | 
			
		||||
            do.reshape(ctx.o_shape_og),
 | 
			
		||||
            dnorm_weight,
 | 
			
		||||
            dnorm_bias,
 | 
			
		||||
            dlinear_weight,
 | 
			
		||||
            dlinear_bias,
 | 
			
		||||
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def layer_norm_swish_gate_fn(
 | 
			
		||||
    x,
 | 
			
		||||
    o,
 | 
			
		||||
    weight,
 | 
			
		||||
    bias,
 | 
			
		||||
    residual=None,
 | 
			
		||||
    prenorm=False,
 | 
			
		||||
    residual_in_fp32=False,
 | 
			
		||||
    eps=1e-6
 | 
			
		||||
):
 | 
			
		||||
    return LayerNormSwishGateFn.apply(
 | 
			
		||||
        x,
 | 
			
		||||
        o,
 | 
			
		||||
        weight,
 | 
			
		||||
        bias,
 | 
			
		||||
        residual,
 | 
			
		||||
        eps,
 | 
			
		||||
        prenorm,
 | 
			
		||||
        residual_in_fp32,
 | 
			
		||||
        False
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rms_norm_swish_gate_fn(
 | 
			
		||||
    x,
 | 
			
		||||
    o,
 | 
			
		||||
    weight,
 | 
			
		||||
    bias,
 | 
			
		||||
    residual=None,
 | 
			
		||||
    prenorm=False,
 | 
			
		||||
    residual_in_fp32=False,
 | 
			
		||||
    eps=1e-6
 | 
			
		||||
):
 | 
			
		||||
    return LayerNormSwishGateFn.apply(
 | 
			
		||||
        x,
 | 
			
		||||
        o,
 | 
			
		||||
        weight,
 | 
			
		||||
        bias,
 | 
			
		||||
        residual,
 | 
			
		||||
        eps,
 | 
			
		||||
        prenorm,
 | 
			
		||||
        residual_in_fp32,
 | 
			
		||||
        True
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def layer_norm_swish_gate_linear_fn(
 | 
			
		||||
    x,
 | 
			
		||||
    o,
 | 
			
		||||
    norm_weight,
 | 
			
		||||
    norm_bias,
 | 
			
		||||
    linear_weight,
 | 
			
		||||
    linear_bias,
 | 
			
		||||
    residual=None,
 | 
			
		||||
    prenorm=False,
 | 
			
		||||
    residual_in_fp32=False,
 | 
			
		||||
    eps=1e-6
 | 
			
		||||
):
 | 
			
		||||
    return LayerNormSwishGateLinearFn.apply(
 | 
			
		||||
        x,
 | 
			
		||||
        o,
 | 
			
		||||
        norm_weight,
 | 
			
		||||
        norm_bias,
 | 
			
		||||
        linear_weight,
 | 
			
		||||
        linear_bias,
 | 
			
		||||
        residual,
 | 
			
		||||
        eps,
 | 
			
		||||
        prenorm,
 | 
			
		||||
        residual_in_fp32,
 | 
			
		||||
        False
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rms_norm_swish_gate_linear_fn(
 | 
			
		||||
    x,
 | 
			
		||||
    o,
 | 
			
		||||
    norm_weight,
 | 
			
		||||
    norm_bias,
 | 
			
		||||
    linear_weight,
 | 
			
		||||
    linear_bias,
 | 
			
		||||
    residual=None,
 | 
			
		||||
    prenorm=False,
 | 
			
		||||
    residual_in_fp32=False,
 | 
			
		||||
    eps=1e-6
 | 
			
		||||
):
 | 
			
		||||
    return LayerNormSwishGateLinearFn.apply(
 | 
			
		||||
        x,
 | 
			
		||||
        o,
 | 
			
		||||
        norm_weight,
 | 
			
		||||
        norm_bias,
 | 
			
		||||
        linear_weight,
 | 
			
		||||
        linear_bias,
 | 
			
		||||
        residual,
 | 
			
		||||
        eps,
 | 
			
		||||
        prenorm,
 | 
			
		||||
        residual_in_fp32,
 | 
			
		||||
        True
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedLayerNormSwishGate(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size,
 | 
			
		||||
        elementwise_affine: bool = True,
 | 
			
		||||
        eps=1e-5
 | 
			
		||||
    ) -> FusedLayerNormSwishGate:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.elementwise_affine = elementwise_affine
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        if elementwise_affine:
 | 
			
		||||
            self.weight = nn.Parameter(torch.ones(hidden_size))
 | 
			
		||||
        else:
 | 
			
		||||
            self.register_parameter("weight", None)
 | 
			
		||||
        self.register_parameter("bias", None)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        s = f"{self.__class__.__name__}({self.hidden_size}"
 | 
			
		||||
        if not self.elementwise_affine:
 | 
			
		||||
            s += f", elementwise_affine={self.elementwise_affine}"
 | 
			
		||||
        s += f", eps={self.eps}"
 | 
			
		||||
        s += ")"
 | 
			
		||||
        return s
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
 | 
			
		||||
        return layer_norm_swish_gate_fn(
 | 
			
		||||
            x,
 | 
			
		||||
            o,
 | 
			
		||||
            self.weight,
 | 
			
		||||
            self.bias,
 | 
			
		||||
            residual=residual,
 | 
			
		||||
            eps=self.eps,
 | 
			
		||||
            prenorm=prenorm,
 | 
			
		||||
            residual_in_fp32=residual_in_fp32
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedRMSNormSwishGate(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size,
 | 
			
		||||
        elementwise_affine: bool = True,
 | 
			
		||||
        eps=1e-5
 | 
			
		||||
    ) -> FusedRMSNormSwishGate:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.elementwise_affine = elementwise_affine
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        if elementwise_affine:
 | 
			
		||||
            self.weight = nn.Parameter(torch.ones(hidden_size))
 | 
			
		||||
        else:
 | 
			
		||||
            self.register_parameter("weight", None)
 | 
			
		||||
        self.register_parameter("bias", None)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        s = f"{self.__class__.__name__}({self.hidden_size}"
 | 
			
		||||
        if not self.elementwise_affine:
 | 
			
		||||
            s += f", elementwise_affine={self.elementwise_affine}"
 | 
			
		||||
        s += f", eps={self.eps}"
 | 
			
		||||
        s += ")"
 | 
			
		||||
        return s
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
 | 
			
		||||
        return rms_norm_swish_gate_fn(
 | 
			
		||||
            x,
 | 
			
		||||
            o,
 | 
			
		||||
            self.weight,
 | 
			
		||||
            self.bias,
 | 
			
		||||
            residual=residual,
 | 
			
		||||
            eps=self.eps,
 | 
			
		||||
            prenorm=prenorm,
 | 
			
		||||
            residual_in_fp32=residual_in_fp32
 | 
			
		||||
        )
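

# Minimal usage sketch (illustrative addition, not part of the vendored file). Shapes are
# assumptions; the fused kernel requires tensors on a CUDA device.
def _example_fused_rms_norm_swish_gate():
    norm = FusedRMSNormSwishGate(512).cuda()
    x = torch.randn(4, 128, 512, device="cuda")  # hidden states
    o = torch.randn(4, 128, 512, device="cuda")  # gate branch
    # y = rms_norm(x) * o * sigmoid(o)
    return norm(x, o)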
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedLayerNormSwishGateLinear(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size,
 | 
			
		||||
        elementwise_affine: bool = True,
 | 
			
		||||
        eps=1e-5
 | 
			
		||||
    ) -> FusedLayerNormSwishGateLinear:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.elementwise_affine = elementwise_affine
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        if elementwise_affine:
 | 
			
		||||
            self.weight = nn.Parameter(torch.ones(hidden_size))
 | 
			
		||||
        else:
 | 
			
		||||
            self.register_parameter("weight", None)
 | 
			
		||||
        self.register_parameter("bias", None)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        s = f"{self.__class__.__name__}({self.hidden_size}"
 | 
			
		||||
        if not self.elementwise_affine:
 | 
			
		||||
            s += f", elementwise_affine={self.elementwise_affine}"
 | 
			
		||||
        s += f", eps={self.eps}"
 | 
			
		||||
        s += ")"
 | 
			
		||||
        return s
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
 | 
			
		||||
        return layer_norm_swish_gate_linear_fn(
 | 
			
		||||
            x,
 | 
			
		||||
            o,
 | 
			
		||||
            self.weight,
 | 
			
		||||
            self.bias,
 | 
			
		||||
            weight,
 | 
			
		||||
            bias,
 | 
			
		||||
            residual=residual,
 | 
			
		||||
            eps=self.eps,
 | 
			
		||||
            prenorm=prenorm,
 | 
			
		||||
            residual_in_fp32=residual_in_fp32
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedRMSNormSwishGateLinear(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size,
 | 
			
		||||
        elementwise_affine: bool = True,
 | 
			
		||||
        eps=1e-5
 | 
			
		||||
    ) -> FusedRMSNormSwishGateLinear:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.elementwise_affine = elementwise_affine
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        if elementwise_affine:
 | 
			
		||||
            self.weight = nn.Parameter(torch.ones(hidden_size))
 | 
			
		||||
        else:
 | 
			
		||||
            self.register_parameter("weight", None)
 | 
			
		||||
        self.register_parameter("bias", None)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        s = f"{self.__class__.__name__}({self.hidden_size}"
 | 
			
		||||
        if not self.elementwise_affine:
 | 
			
		||||
            s += f", elementwise_affine={self.elementwise_affine}"
 | 
			
		||||
        s += f", eps={self.eps}"
 | 
			
		||||
        s += ")"
 | 
			
		||||
        return s
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
 | 
			
		||||
        return rms_norm_swish_gate_linear_fn(
 | 
			
		||||
            x,
 | 
			
		||||
            o,
 | 
			
		||||
            self.weight,
 | 
			
		||||
            self.bias,
 | 
			
		||||
            weight,
 | 
			
		||||
            bias,
 | 
			
		||||
            residual=residual,
 | 
			
		||||
            eps=self.eps,
 | 
			
		||||
            prenorm=prenorm,
 | 
			
		||||
            residual_in_fp32=residual_in_fp32
 | 
			
		||||
        )
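
# Added note (not part of the vendored file): a plain-PyTorch sketch of what the
# fused norm + swish-gate modules above are assumed to compute, i.e. the input is
# (layer- or RMS-)normalized and then gated by silu(o) = o * sigmoid(o). The fused
# Triton paths are layer_norm_swish_gate_fn / rms_norm_swish_gate_fn; this is only
# a reference for shapes and rough numerics, not the fused implementation itself.
def rms_norm_swish_gate_ref(x, o, weight=None, eps=1e-5):
    # RMS-normalize x over the last dimension, optionally scale, then gate by silu(o).
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    y = x.float() * rstd
    if weight is not None:
        y = y * weight.float()
    return (y * torch.nn.functional.silu(o.float())).to(x.dtype)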
216 finetune/lora/v6/fla/modules/l2norm.py vendored Normal file
@@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _l2_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_x_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    xbar = tl.where(cols < N, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0)
    rstd = 1 / tl.sqrt(var + eps)
    # tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    y = x * rstd
    # Write output
    tl.store(Y + cols, y, mask=mask)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _l2_norm_bwd_kernel(
    X,  # pointer to the input
    # Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    DX += row * stride_x_row
    DY += row * stride_x_row

    # Y += row * stride_y_row
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    x = tl.where(cols < N, x, 0.0)
    var = tl.sum(x * x)
    rstd = 1 / tl.sqrt(var + eps)
    # tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    # y = x * rstd
    dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32)
    dy = tl.where(cols < N, dy, 0.0)
    # dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x
    dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x
    tl.store(DX + cols, dx, mask=mask)

def _l2_norm_fwd(
    x, eps=1e-6
):
    x_shape_og = x.shape
    x = x.reshape(-1, x.shape[-1])
    if x.stride(-1) != 1:
        x = x.contiguous()
        M, N = x.shape
    assert x.stride(-1) == 1
    # allocate output
    y = torch.empty_like(x)
    assert y.stride(-1) == 1
    N = x.shape[-1]
    M = x.shape[0]
    # rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError(
            "This l2 norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    with torch.cuda.device(x.device.index):
        _l2_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            x.stride(0),
            N,
            eps,
            # is_rms_norm,
            BLOCK_N,
            # residual is not None,
            # residual_out is not None,
            # bias is not None,
        )
    return y.reshape(x_shape_og)

def _l2_norm_bwd(
    x, dy, eps=1e-5,
):
    x_shape_og = x.shape
    x = x.reshape(-1, dy.shape[-1])
    dy = dy.reshape(-1, dy.shape[-1])
    if dy.stride(-1) != 1:
        dy = dy.contiguous()
    assert dy.shape == x.shape
    # allocate output
    dx = torch.empty_like(x)
    N = x.shape[-1]
    M = x.shape[0]
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    # rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError(
            "This l2 norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    with torch.cuda.device(x.device.index):
        _l2_norm_bwd_kernel[(M,)](
            x,
            dy,
            dx,
            x.stride(0),
            N,
            eps,
            BLOCK_N,
        )
    return dx.reshape(x_shape_og)


class L2NormFN(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        eps=1e-6,
    ):
        # reshape input data into 2D tensor
        y = _l2_norm_fwd(x, eps)
        ctx.x_shape_og = x.shape  # was an undefined `x_shape_og`; record the original shape here
        ctx.eps = eps
        ctx.x_dtype = x.dtype
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, dy, *args):
        x, = ctx.saved_tensors
        dx = _l2_norm_bwd(
            x,
            dy,
            ctx.eps,
        )
        return (
            dx,
            None
        )

l2_norm_fn = L2NormFN.apply

if __name__ == '__main__':
    x = torch.rand(10, 10, 100).cuda().requires_grad_(True)
    y = torch.nn.functional.normalize(x, dim=-1, p=2)
    dy = torch.rand_like(y)
    y.backward(dy, retain_graph=True)
    x_grad, x.grad = x.grad, None
    y2 = l2_norm_fn(x, 1e-6)
    print((y-y2).abs().max())
    y2.backward(dy, retain_graph=True)
    x_grad2, x.grad = x.grad, None
    print((x_grad2-x_grad).abs().max())
    breakpoint()
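
# Added reference sketch (not in the vendored file): the kernels above implement,
# per row, y = x / sqrt(sum(x^2) + eps) with backward
# dx = dy * rstd - sum(dy * x) * rstd^3 * x, where rstd = 1 / sqrt(sum(x^2) + eps).
# A CPU-only closed-form version, useful for checking the Triton path:
def l2_norm_ref(x, eps=1e-6):
    rstd = 1.0 / torch.sqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)
    return x * rstd

def l2_norm_ref_bwd(x, dy, eps=1e-6):
    rstd = 1.0 / torch.sqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)
    return dy * rstd - (dy * x).sum(dim=-1, keepdim=True) * rstd.pow(3) * x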
							
								
								
									
802 finetune/lora/v6/fla/modules/layernorm.py vendored Normal file
@@ -0,0 +1,802 @@
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# Copyright (c) 2023, Tri Dao.
 | 
			
		||||
# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
 | 
			
		||||
# Implement residual + layer_norm / rms_norm.
 | 
			
		||||
 | 
			
		||||
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
 | 
			
		||||
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
 | 
			
		||||
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
 | 
			
		||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
 | 
			
		||||
from fla.utils import contiguous
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
 | 
			
		||||
    dtype = x.dtype
 | 
			
		||||
    if upcast:
 | 
			
		||||
        weight = weight.float()
 | 
			
		||||
        bias = bias.float() if bias is not None else None
 | 
			
		||||
    if upcast:
 | 
			
		||||
        x = x.float()
 | 
			
		||||
        residual = residual.float() if residual is not None else residual
 | 
			
		||||
    if residual is not None:
 | 
			
		||||
        x = (x + residual).to(x.dtype)
 | 
			
		||||
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
 | 
			
		||||
        dtype
 | 
			
		||||
    )
 | 
			
		||||
    return out if not prenorm else (out, x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
 | 
			
		||||
    dtype = x.dtype
 | 
			
		||||
    if upcast:
 | 
			
		||||
        weight = weight.float()
 | 
			
		||||
        bias = bias.float() if bias is not None else None
 | 
			
		||||
    if upcast:
 | 
			
		||||
        x = x.float()
 | 
			
		||||
        residual = residual.float() if residual is not None else residual
 | 
			
		||||
    if residual is not None:
 | 
			
		||||
        x = (x + residual).to(x.dtype)
 | 
			
		||||
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
 | 
			
		||||
    out = (x * rstd * weight + bias) if bias is not None else (x * rstd * weight)
 | 
			
		||||
    out = out.to(dtype)
 | 
			
		||||
    return out if not prenorm else (out, x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
 | 
			
		||||
)
 | 
			
		||||
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 | 
			
		||||
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
 | 
			
		||||
@triton.jit
 | 
			
		||||
def _layer_norm_fwd_1pass_kernel(
 | 
			
		||||
    X,  # pointer to the input
 | 
			
		||||
    Y,  # pointer to the output
 | 
			
		||||
    W,  # pointer to the weights
 | 
			
		||||
    B,  # pointer to the biases
 | 
			
		||||
    RESIDUAL,  # pointer to the residual
 | 
			
		||||
    RESIDUAL_OUT,  # pointer to the residual
 | 
			
		||||
    Mean,  # pointer to the mean
 | 
			
		||||
    Rstd,  # pointer to the 1/std
 | 
			
		||||
    stride_x_row,  # how much to increase the pointer when moving by 1 row
 | 
			
		||||
    stride_y_row,
 | 
			
		||||
    stride_res_row,
 | 
			
		||||
    stride_res_out_row,
 | 
			
		||||
    N,  # number of columns in X
 | 
			
		||||
    eps,  # epsilon to avoid division by zero
 | 
			
		||||
    IS_RMS_NORM: tl.constexpr,
 | 
			
		||||
    BLOCK_N: tl.constexpr,
 | 
			
		||||
    HAS_RESIDUAL: tl.constexpr,
 | 
			
		||||
    STORE_RESIDUAL_OUT: tl.constexpr,
 | 
			
		||||
    HAS_WEIGHT: tl.constexpr,
 | 
			
		||||
    HAS_BIAS: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    # Map the program id to the row of X and Y it should compute.
 | 
			
		||||
    row = tl.program_id(0)
 | 
			
		||||
    X += row * stride_x_row
 | 
			
		||||
    Y += row * stride_y_row
 | 
			
		||||
    if HAS_RESIDUAL:
 | 
			
		||||
        RESIDUAL += row * stride_res_row
 | 
			
		||||
    if STORE_RESIDUAL_OUT:
 | 
			
		||||
        RESIDUAL_OUT += row * stride_res_out_row
 | 
			
		||||
    # Compute mean and variance
 | 
			
		||||
    cols = tl.arange(0, BLOCK_N)
 | 
			
		||||
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
 | 
			
		||||
    if HAS_RESIDUAL:
 | 
			
		||||
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
 | 
			
		||||
        x += residual
 | 
			
		||||
    if STORE_RESIDUAL_OUT:
 | 
			
		||||
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
 | 
			
		||||
    if not IS_RMS_NORM:
 | 
			
		||||
        mean = tl.sum(x, axis=0) / N
 | 
			
		||||
        tl.store(Mean + row, mean)
 | 
			
		||||
        xbar = tl.where(cols < N, x - mean, 0.0)
 | 
			
		||||
        var = tl.sum(xbar * xbar, axis=0) / N
 | 
			
		||||
    else:
 | 
			
		||||
        xbar = tl.where(cols < N, x, 0.0)
 | 
			
		||||
        var = tl.sum(xbar * xbar, axis=0) / N
 | 
			
		||||
    rstd = 1 / tl.sqrt(var + eps)
 | 
			
		||||
    tl.store(Rstd + row, rstd)
 | 
			
		||||
    # Normalize and apply linear transformation
 | 
			
		||||
    mask = cols < N
 | 
			
		||||
    if HAS_WEIGHT:
 | 
			
		||||
        w = tl.load(W + cols, mask=mask).to(tl.float32)
 | 
			
		||||
    if HAS_BIAS:
 | 
			
		||||
        b = tl.load(B + cols, mask=mask).to(tl.float32)
 | 
			
		||||
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
 | 
			
		||||
 | 
			
		||||
    y = x_hat * w if HAS_WEIGHT else x_hat
 | 
			
		||||
    if HAS_BIAS:
 | 
			
		||||
        y = y + b
 | 
			
		||||
    # Write output
 | 
			
		||||
    tl.store(Y + cols, y, mask=mask)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _layer_norm_fwd(
 | 
			
		||||
    x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
 | 
			
		||||
):
 | 
			
		||||
    if residual is not None:
 | 
			
		||||
        residual_dtype = residual.dtype
 | 
			
		||||
    M, N = x.shape
 | 
			
		||||
    assert x.stride(-1) == 1
 | 
			
		||||
    if residual is not None:
 | 
			
		||||
        assert residual.stride(-1) == 1
 | 
			
		||||
        assert residual.shape == (M, N)
 | 
			
		||||
    if weight is not None:
 | 
			
		||||
        assert weight.shape == (N,)
 | 
			
		||||
        assert weight.stride(-1) == 1
 | 
			
		||||
    if bias is not None:
 | 
			
		||||
        assert bias.stride(-1) == 1
 | 
			
		||||
        assert bias.shape == (N,)
 | 
			
		||||
    # allocate output
 | 
			
		||||
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
 | 
			
		||||
    assert y.stride(-1) == 1
 | 
			
		||||
    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
 | 
			
		||||
        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
 | 
			
		||||
        assert residual_out.stride(-1) == 1
 | 
			
		||||
    else:
 | 
			
		||||
        residual_out = None
 | 
			
		||||
    mean = torch.empty((M,), dtype=torch.float32,
 | 
			
		||||
                       device="cuda") if not is_rms_norm else None
 | 
			
		||||
    rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
 | 
			
		||||
    # Less than 64KB per feature: enqueue fused kernel
 | 
			
		||||
    MAX_FUSED_SIZE = 65536 // x.element_size()
 | 
			
		||||
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
 | 
			
		||||
    if N > BLOCK_N:
 | 
			
		||||
        raise RuntimeError(
 | 
			
		||||
            "This layer norm doesn't support feature dim >= 64KB.")
 | 
			
		||||
    # heuristics for number of warps
 | 
			
		||||
    with torch.cuda.device(x.device.index):
 | 
			
		||||
        _layer_norm_fwd_1pass_kernel[(M,)](
 | 
			
		||||
            x,
 | 
			
		||||
            y,
 | 
			
		||||
            weight,
 | 
			
		||||
            bias,
 | 
			
		||||
            residual,
 | 
			
		||||
            residual_out,
 | 
			
		||||
            mean,
 | 
			
		||||
            rstd,
 | 
			
		||||
            x.stride(0),
 | 
			
		||||
            y.stride(0),
 | 
			
		||||
            residual.stride(0) if residual is not None else 0,
 | 
			
		||||
            residual_out.stride(0) if residual_out is not None else 0,
 | 
			
		||||
            N,
 | 
			
		||||
            eps,
 | 
			
		||||
            is_rms_norm,
 | 
			
		||||
            BLOCK_N,
 | 
			
		||||
            residual is not None,
 | 
			
		||||
            residual_out is not None,
 | 
			
		||||
            weight is not None,
 | 
			
		||||
            bias is not None,
 | 
			
		||||
        )
 | 
			
		||||
    # residual_out is None if residual is None and residual_dtype == input_dtype
 | 
			
		||||
    return y, mean, rstd, residual_out if residual_out is not None else x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
 | 
			
		||||
)
 | 
			
		||||
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 | 
			
		||||
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
 | 
			
		||||
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
 | 
			
		||||
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
 | 
			
		||||
@triton.jit
 | 
			
		||||
def _layer_norm_bwd_kernel(
 | 
			
		||||
    X,  # pointer to the input
 | 
			
		||||
    W,  # pointer to the weights
 | 
			
		||||
    B,  # pointer to the biases
 | 
			
		||||
    Y,  # pointer to the output to be recomputed
 | 
			
		||||
    DY,  # pointer to the output gradient
 | 
			
		||||
    DX,  # pointer to the input gradient
 | 
			
		||||
    DW,  # pointer to the partial sum of weights gradient
 | 
			
		||||
    DB,  # pointer to the partial sum of biases gradient
 | 
			
		||||
    DRESIDUAL,
 | 
			
		||||
    DRESIDUAL_IN,
 | 
			
		||||
    Mean,  # pointer to the mean
 | 
			
		||||
    Rstd,  # pointer to the 1/std
 | 
			
		||||
    stride_x_row,  # how much to increase the pointer when moving by 1 row
 | 
			
		||||
    stride_y_row,
 | 
			
		||||
    stride_dy_row,
 | 
			
		||||
    stride_dx_row,
 | 
			
		||||
    stride_dres_row,
 | 
			
		||||
    stride_dres_in_row,
 | 
			
		||||
    M,  # number of rows in X
 | 
			
		||||
    N,  # number of columns in X
 | 
			
		||||
    eps,  # epsilon to avoid division by zero
 | 
			
		||||
    rows_per_program,
 | 
			
		||||
    IS_RMS_NORM: tl.constexpr,
 | 
			
		||||
    BLOCK_N: tl.constexpr,
 | 
			
		||||
    HAS_DRESIDUAL: tl.constexpr,
 | 
			
		||||
    STORE_DRESIDUAL: tl.constexpr,
 | 
			
		||||
    HAS_WEIGHT: tl.constexpr,
 | 
			
		||||
    HAS_BIAS: tl.constexpr,
 | 
			
		||||
    RECOMPUTE_OUTPUT: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    # Map the program id to the elements of X, DX, and DY it should compute.
 | 
			
		||||
    row_block_id = tl.program_id(0)
 | 
			
		||||
    row_start = row_block_id * rows_per_program
 | 
			
		||||
    cols = tl.arange(0, BLOCK_N)
 | 
			
		||||
    mask = cols < N
 | 
			
		||||
    X += row_start * stride_x_row
 | 
			
		||||
    if HAS_DRESIDUAL:
 | 
			
		||||
        DRESIDUAL += row_start * stride_dres_row
 | 
			
		||||
    if STORE_DRESIDUAL:
 | 
			
		||||
        DRESIDUAL_IN += row_start * stride_dres_in_row
 | 
			
		||||
    DY += row_start * stride_dy_row
 | 
			
		||||
    DX += row_start * stride_dx_row
 | 
			
		||||
    if RECOMPUTE_OUTPUT:
 | 
			
		||||
        Y += row_start * stride_y_row
 | 
			
		||||
    if HAS_WEIGHT:
 | 
			
		||||
        w = tl.load(W + cols, mask=mask).to(tl.float32)
 | 
			
		||||
        dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
 | 
			
		||||
    if RECOMPUTE_OUTPUT and HAS_BIAS:
 | 
			
		||||
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
 | 
			
		||||
    if HAS_BIAS:
 | 
			
		||||
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
 | 
			
		||||
    row_end = min((row_block_id + 1) * rows_per_program, M)
 | 
			
		||||
    for row in range(row_start, row_end):
 | 
			
		||||
        # Load data to SRAM
 | 
			
		||||
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        if not IS_RMS_NORM:
 | 
			
		||||
            mean = tl.load(Mean + row)
 | 
			
		||||
        rstd = tl.load(Rstd + row)
 | 
			
		||||
        # Compute dx
 | 
			
		||||
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
 | 
			
		||||
        xhat = tl.where(mask, xhat, 0.0)
 | 
			
		||||
        if RECOMPUTE_OUTPUT:
 | 
			
		||||
            y = xhat * w if HAS_WEIGHT else xhat
 | 
			
		||||
            if HAS_BIAS:
 | 
			
		||||
                y = y + b
 | 
			
		||||
            tl.store(Y + cols, y, mask=mask)
 | 
			
		||||
        wdy = dy
 | 
			
		||||
        if HAS_WEIGHT:
 | 
			
		||||
            wdy = dy * w
 | 
			
		||||
            dw += dy * xhat
 | 
			
		||||
        if HAS_BIAS:
 | 
			
		||||
            db += dy
 | 
			
		||||
        if not IS_RMS_NORM:
 | 
			
		||||
            c1 = tl.sum(xhat * wdy, axis=0) / N
 | 
			
		||||
            c2 = tl.sum(wdy, axis=0) / N
 | 
			
		||||
            dx = (wdy - (xhat * c1 + c2)) * rstd
 | 
			
		||||
        else:
 | 
			
		||||
            c1 = tl.sum(xhat * wdy, axis=0) / N
 | 
			
		||||
            dx = (wdy - xhat * c1) * rstd
 | 
			
		||||
        if HAS_DRESIDUAL:
 | 
			
		||||
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
            dx += dres
 | 
			
		||||
        # Write dx
 | 
			
		||||
        if STORE_DRESIDUAL:
 | 
			
		||||
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
 | 
			
		||||
        tl.store(DX + cols, dx, mask=mask)
 | 
			
		||||
 | 
			
		||||
        X += stride_x_row
 | 
			
		||||
        if HAS_DRESIDUAL:
 | 
			
		||||
            DRESIDUAL += stride_dres_row
 | 
			
		||||
        if STORE_DRESIDUAL:
 | 
			
		||||
            DRESIDUAL_IN += stride_dres_in_row
 | 
			
		||||
        if RECOMPUTE_OUTPUT:
 | 
			
		||||
            Y += stride_y_row
 | 
			
		||||
        DY += stride_dy_row
 | 
			
		||||
        DX += stride_dx_row
 | 
			
		||||
    if HAS_WEIGHT:
 | 
			
		||||
        tl.store(DW + row_block_id * N + cols, dw, mask=mask)
 | 
			
		||||
    if HAS_BIAS:
 | 
			
		||||
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _layer_norm_bwd(
 | 
			
		||||
    dy,
 | 
			
		||||
    x,
 | 
			
		||||
    weight,
 | 
			
		||||
    bias,
 | 
			
		||||
    eps,
 | 
			
		||||
    mean,
 | 
			
		||||
    rstd,
 | 
			
		||||
    dresidual=None,
 | 
			
		||||
    has_residual=False,
 | 
			
		||||
    is_rms_norm=False,
 | 
			
		||||
    x_dtype=None,
 | 
			
		||||
    recompute_output=False,
 | 
			
		||||
):
 | 
			
		||||
    M, N = x.shape
 | 
			
		||||
    assert x.stride(-1) == 1
 | 
			
		||||
    assert dy.stride(-1) == 1
 | 
			
		||||
    assert dy.shape == (M, N)
 | 
			
		||||
    if dresidual is not None:
 | 
			
		||||
        assert dresidual.stride(-1) == 1
 | 
			
		||||
        assert dresidual.shape == (M, N)
 | 
			
		||||
    if weight is not None:
 | 
			
		||||
        assert weight.shape == (N,)
 | 
			
		||||
        assert weight.stride(-1) == 1
 | 
			
		||||
    if bias is not None:
 | 
			
		||||
        assert bias.stride(-1) == 1
 | 
			
		||||
        assert bias.shape == (N,)
 | 
			
		||||
    # allocate output
 | 
			
		||||
    dx = (
 | 
			
		||||
        torch.empty_like(x)
 | 
			
		||||
        if x_dtype is None
 | 
			
		||||
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
 | 
			
		||||
    )
 | 
			
		||||
    dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
 | 
			
		||||
 | 
			
		||||
    # Less than 64KB per feature: enqueue fused kernel
 | 
			
		||||
    MAX_FUSED_SIZE = 65536 // x.element_size()
 | 
			
		||||
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
 | 
			
		||||
    if N > BLOCK_N:
 | 
			
		||||
        raise RuntimeError(
 | 
			
		||||
            "This layer norm doesn't support feature dim >= 64KB.")
 | 
			
		||||
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
 | 
			
		||||
    _dw = (
 | 
			
		||||
        torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
 | 
			
		||||
        if weight is not None
 | 
			
		||||
        else None
 | 
			
		||||
    )
 | 
			
		||||
    _db = (
 | 
			
		||||
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
 | 
			
		||||
        if bias is not None
 | 
			
		||||
        else None
 | 
			
		||||
    )
 | 
			
		||||
    rows_per_program = math.ceil(M / sm_count)
 | 
			
		||||
    grid = (sm_count,)
 | 
			
		||||
    with torch.cuda.device(x.device.index):
 | 
			
		||||
        _layer_norm_bwd_kernel[grid](
 | 
			
		||||
            x,
 | 
			
		||||
            weight,
 | 
			
		||||
            bias,
 | 
			
		||||
            y,
 | 
			
		||||
            dy,
 | 
			
		||||
            dx,
 | 
			
		||||
            _dw,
 | 
			
		||||
            _db,
 | 
			
		||||
            dresidual,
 | 
			
		||||
            dresidual_in,
 | 
			
		||||
            mean,
 | 
			
		||||
            rstd,
 | 
			
		||||
            x.stride(0),
 | 
			
		||||
            0 if not recompute_output else y.stride(0),
 | 
			
		||||
            dy.stride(0),
 | 
			
		||||
            dx.stride(0),
 | 
			
		||||
            dresidual.stride(0) if dresidual is not None else 0,
 | 
			
		||||
            dresidual_in.stride(0) if dresidual_in is not None else 0,
 | 
			
		||||
            M,
 | 
			
		||||
            N,
 | 
			
		||||
            eps,
 | 
			
		||||
            rows_per_program,
 | 
			
		||||
            is_rms_norm,
 | 
			
		||||
            BLOCK_N,
 | 
			
		||||
            dresidual is not None,
 | 
			
		||||
            dresidual_in is not None,
 | 
			
		||||
            weight is not None,
 | 
			
		||||
            bias is not None,
 | 
			
		||||
        )
 | 
			
		||||
    dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
 | 
			
		||||
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
 | 
			
		||||
    # Don't need to compute dresidual_in separately in this case
 | 
			
		||||
    if has_residual and dx.dtype == x.dtype:
 | 
			
		||||
        dresidual_in = dx
 | 
			
		||||
    return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
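
# Added sketch (not part of the vendored file): as noted at the top of this file,
# dw/db are accumulated per program into the (sm_count, N) buffers _dw/_db and
# reduced with .sum(0) above. The same partial-sum reduction in plain PyTorch,
# for xhat and dy of shape (M, N) split into arbitrary row chunks:
def dw_db_by_partial_sums(xhat, dy, num_chunks=4):
    partial_dw = torch.stack([(d * h).sum(0) for h, d in
                              zip(xhat.chunk(num_chunks, 0), dy.chunk(num_chunks, 0))])
    partial_db = torch.stack([d.sum(0) for d in dy.chunk(num_chunks, 0)])
    # Summing the partials reproduces the full-batch dw = sum(dy * xhat) and db = sum(dy).
    return partial_dw.sum(0), partial_db.sum(0)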
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LayerNormFn(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def forward(
 | 
			
		||||
        ctx,
 | 
			
		||||
        x,
 | 
			
		||||
        weight,
 | 
			
		||||
        bias,
 | 
			
		||||
        residual=None,
 | 
			
		||||
        eps=1e-6,
 | 
			
		||||
        prenorm=False,
 | 
			
		||||
        residual_in_fp32=False,
 | 
			
		||||
        is_rms_norm=False,
 | 
			
		||||
    ):
 | 
			
		||||
        x_shape_og = x.shape
 | 
			
		||||
        # reshape input data into 2D tensor
 | 
			
		||||
        x = x.reshape(-1, x.shape[-1])
 | 
			
		||||
        if residual is not None:
 | 
			
		||||
            assert residual.shape == x_shape_og
 | 
			
		||||
            residual = residual.reshape(-1, residual.shape[-1])
 | 
			
		||||
        residual_dtype = (
 | 
			
		||||
            residual.dtype
 | 
			
		||||
            if residual is not None
 | 
			
		||||
            else (torch.float32 if residual_in_fp32 else None)
 | 
			
		||||
        )
 | 
			
		||||
        y, mean, rstd, residual_out = _layer_norm_fwd(
 | 
			
		||||
            x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
 | 
			
		||||
        )
 | 
			
		||||
        ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
 | 
			
		||||
        ctx.x_shape_og = x_shape_og
 | 
			
		||||
        ctx.eps = eps
 | 
			
		||||
        ctx.is_rms_norm = is_rms_norm
 | 
			
		||||
        ctx.has_residual = residual is not None
 | 
			
		||||
        ctx.prenorm = prenorm
 | 
			
		||||
        ctx.x_dtype = x.dtype
 | 
			
		||||
        y = y.reshape(x_shape_og)
 | 
			
		||||
        return y if not prenorm else (y, residual_out.reshape(x_shape_og))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def backward(ctx, dy, *args):
 | 
			
		||||
        x, weight, bias, mean, rstd = ctx.saved_tensors
 | 
			
		||||
        dy = dy.reshape(-1, dy.shape[-1])
 | 
			
		||||
        assert dy.shape == x.shape
 | 
			
		||||
        if ctx.prenorm:
 | 
			
		||||
            dresidual = args[0]
 | 
			
		||||
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
 | 
			
		||||
            assert dresidual.shape == x.shape
 | 
			
		||||
        else:
 | 
			
		||||
            dresidual = None
 | 
			
		||||
        dx, dw, db, dresidual_in = _layer_norm_bwd(
 | 
			
		||||
            dy,
 | 
			
		||||
            x,
 | 
			
		||||
            weight,
 | 
			
		||||
            bias,
 | 
			
		||||
            ctx.eps,
 | 
			
		||||
            mean,
 | 
			
		||||
            rstd,
 | 
			
		||||
            dresidual,
 | 
			
		||||
            ctx.has_residual,
 | 
			
		||||
            ctx.is_rms_norm,
 | 
			
		||||
            x_dtype=ctx.x_dtype,
 | 
			
		||||
        )
 | 
			
		||||
        return (
 | 
			
		||||
            dx.reshape(ctx.x_shape_og),
 | 
			
		||||
            dw,
 | 
			
		||||
            db,
 | 
			
		||||
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def layer_norm_fn(
 | 
			
		||||
    x,
 | 
			
		||||
    weight,
 | 
			
		||||
    bias,
 | 
			
		||||
    residual=None,
 | 
			
		||||
    eps=1e-6,
 | 
			
		||||
    prenorm=False,
 | 
			
		||||
    residual_in_fp32=False,
 | 
			
		||||
    is_rms_norm=False,
 | 
			
		||||
):
 | 
			
		||||
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rms_norm_fn(
 | 
			
		||||
    x,
 | 
			
		||||
    weight,
 | 
			
		||||
    bias,
 | 
			
		||||
    residual=None,
 | 
			
		||||
    prenorm=False,
 | 
			
		||||
    residual_in_fp32=False,
 | 
			
		||||
    eps=1e-6
 | 
			
		||||
):
 | 
			
		||||
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
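
# Added usage sketch (illustrative only; needs a CUDA device with Triton): the
# fused rms_norm_fn should match the eager rms_norm_ref defined above.
def _check_rms_norm(M=4, N=512, device="cuda", dtype=torch.float16):
    x = torch.randn(M, N, device=device, dtype=dtype)
    w = torch.ones(N, device=device, dtype=dtype)
    y_fused = rms_norm_fn(x, w, None, residual=None, eps=1e-6)
    y_ref = rms_norm_ref(x, w, None, eps=1e-6, upcast=True)
    return (y_fused.float() - y_ref.float()).abs().max()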
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LayerNorm(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        elementwise_affine: bool = True,
 | 
			
		||||
        eps: float = 1e-5
 | 
			
		||||
    ) -> LayerNorm:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.elementwise_affine = elementwise_affine
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        if elementwise_affine:
 | 
			
		||||
            self.weight = nn.Parameter(torch.ones(hidden_size))
 | 
			
		||||
        else:
 | 
			
		||||
            self.register_parameter("weight", None)
 | 
			
		||||
        self.register_parameter("bias", None)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        s = f"{self.__class__.__name__}({self.hidden_size}"
 | 
			
		||||
        if not self.elementwise_affine:
 | 
			
		||||
            s += f", elementwise_affine={self.elementwise_affine}"
 | 
			
		||||
        s += f", eps={self.eps}"
 | 
			
		||||
        s += ")"
 | 
			
		||||
        return s
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
 | 
			
		||||
        return layer_norm_fn(
 | 
			
		||||
            x,
 | 
			
		||||
            self.weight,
 | 
			
		||||
            self.bias,
 | 
			
		||||
            residual=residual,
 | 
			
		||||
            eps=self.eps,
 | 
			
		||||
            prenorm=prenorm,
 | 
			
		||||
            residual_in_fp32=residual_in_fp32
 | 
			
		||||
        )
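
# Added usage sketch (not in the vendored file): with prenorm=True the module
# returns (normed, residual_out), so the updated residual stream can be threaded
# into the next block; residual_in_fp32 keeps that stream in fp32.
def _prenorm_step(norm: "LayerNorm", x, resid):
    return norm(x, residual=resid, prenorm=True, residual_in_fp32=True)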
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RMSNorm(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        elementwise_affine: bool = True,
 | 
			
		||||
        eps: float = 1e-5
 | 
			
		||||
    ) -> RMSNorm:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.elementwise_affine = elementwise_affine
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        if elementwise_affine:
 | 
			
		||||
            self.weight = nn.Parameter(torch.ones(hidden_size))
 | 
			
		||||
        else:
 | 
			
		||||
            self.register_parameter("weight", None)
 | 
			
		||||
        self.register_parameter("bias", None)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        s = f"{self.__class__.__name__}({self.hidden_size}"
 | 
			
		||||
        if not self.elementwise_affine:
 | 
			
		||||
            s += f", elementwise_affine={self.elementwise_affine}"
 | 
			
		||||
        s += f", eps={self.eps}"
 | 
			
		||||
        s += ")"
 | 
			
		||||
        return s
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
 | 
			
		||||
        return rms_norm_fn(
 | 
			
		||||
            x,
 | 
			
		||||
            self.weight,
 | 
			
		||||
            self.bias,
 | 
			
		||||
            residual=residual,
 | 
			
		||||
            eps=self.eps,
 | 
			
		||||
            prenorm=prenorm,
 | 
			
		||||
            residual_in_fp32=residual_in_fp32,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LayerNormLinearFn(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def forward(
 | 
			
		||||
        ctx,
 | 
			
		||||
        x,
 | 
			
		||||
        norm_weight,
 | 
			
		||||
        norm_bias,
 | 
			
		||||
        linear_weight,
 | 
			
		||||
        linear_bias,
 | 
			
		||||
        residual=None,
 | 
			
		||||
        eps=1e-6,
 | 
			
		||||
        prenorm=False,
 | 
			
		||||
        residual_in_fp32=False,
 | 
			
		||||
        is_rms_norm=False,
 | 
			
		||||
    ):
 | 
			
		||||
        x_shape_og = x.shape
 | 
			
		||||
        # reshape input data into 2D tensor
 | 
			
		||||
        x = x.reshape(-1, x.shape[-1])
 | 
			
		||||
        if residual is not None:
 | 
			
		||||
            assert residual.shape == x_shape_og
 | 
			
		||||
            residual = residual.reshape(-1, residual.shape[-1])
 | 
			
		||||
        residual_dtype = (
 | 
			
		||||
            residual.dtype
 | 
			
		||||
            if residual is not None
 | 
			
		||||
            else (torch.float32 if residual_in_fp32 else None)
 | 
			
		||||
        )
 | 
			
		||||
        y, mean, rstd, residual_out = _layer_norm_fwd(
 | 
			
		||||
            x,
 | 
			
		||||
            norm_weight,
 | 
			
		||||
            norm_bias,
 | 
			
		||||
            eps,
 | 
			
		||||
            residual,
 | 
			
		||||
            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
 | 
			
		||||
            residual_dtype=residual_dtype,
 | 
			
		||||
            is_rms_norm=is_rms_norm,
 | 
			
		||||
        )
 | 
			
		||||
        y = y.reshape(x_shape_og)
 | 
			
		||||
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
 | 
			
		||||
        linear_weight = linear_weight.to(dtype)
 | 
			
		||||
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
 | 
			
		||||
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
 | 
			
		||||
        # We don't store y, will be recomputed in the backward pass to save memory
 | 
			
		||||
        ctx.save_for_backward(residual_out, norm_weight,
 | 
			
		||||
                              norm_bias, linear_weight, mean, rstd)
 | 
			
		||||
        ctx.x_shape_og = x_shape_og
 | 
			
		||||
        ctx.eps = eps
 | 
			
		||||
        ctx.is_rms_norm = is_rms_norm
 | 
			
		||||
        ctx.has_residual = residual is not None
 | 
			
		||||
        ctx.prenorm = prenorm
 | 
			
		||||
        ctx.x_dtype = x.dtype
 | 
			
		||||
        ctx.linear_bias_is_none = linear_bias is None
 | 
			
		||||
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def backward(ctx, dout, *args):
 | 
			
		||||
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
 | 
			
		||||
        dout = dout.reshape(-1, dout.shape[-1])
 | 
			
		||||
        dy = F.linear(dout, linear_weight.t())
 | 
			
		||||
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
 | 
			
		||||
        assert dy.shape == x.shape
 | 
			
		||||
        if ctx.prenorm:
 | 
			
		||||
            dresidual = args[0]
 | 
			
		||||
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
 | 
			
		||||
            assert dresidual.shape == x.shape
 | 
			
		||||
        else:
 | 
			
		||||
            dresidual = None
 | 
			
		||||
        dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
 | 
			
		||||
            dy,
 | 
			
		||||
            x,
 | 
			
		||||
            norm_weight,
 | 
			
		||||
            norm_bias,
 | 
			
		||||
            ctx.eps,
 | 
			
		||||
            mean,
 | 
			
		||||
            rstd,
 | 
			
		||||
            dresidual,
 | 
			
		||||
            ctx.has_residual,
 | 
			
		||||
            ctx.is_rms_norm,
 | 
			
		||||
            x_dtype=ctx.x_dtype,
 | 
			
		||||
            recompute_output=True,
 | 
			
		||||
        )
 | 
			
		||||
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
 | 
			
		||||
        return (
 | 
			
		||||
            dx.reshape(ctx.x_shape_og),
 | 
			
		||||
            dnorm_weight,
 | 
			
		||||
            dnorm_bias,
 | 
			
		||||
            dlinear_weight,
 | 
			
		||||
            dlinear_bias,
 | 
			
		||||
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
            None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def layer_norm_linear_fn(
 | 
			
		||||
    x,
 | 
			
		||||
    norm_weight,
 | 
			
		||||
    norm_bias,
 | 
			
		||||
    linear_weight,
 | 
			
		||||
    linear_bias,
 | 
			
		||||
    residual=None,
 | 
			
		||||
    eps=1e-6,
 | 
			
		||||
    prenorm=False,
 | 
			
		||||
    residual_in_fp32=False,
 | 
			
		||||
    is_rms_norm=False,
 | 
			
		||||
):
 | 
			
		||||
    return LayerNormLinearFn.apply(
 | 
			
		||||
        x,
 | 
			
		||||
        norm_weight,
 | 
			
		||||
        norm_bias,
 | 
			
		||||
        linear_weight,
 | 
			
		||||
        linear_bias,
 | 
			
		||||
        residual,
 | 
			
		||||
        eps,
 | 
			
		||||
        prenorm,
 | 
			
		||||
        residual_in_fp32,
 | 
			
		||||
        is_rms_norm,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LayerNormLinear(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size,
 | 
			
		||||
        elementwise_affine: bool = True,
 | 
			
		||||
        eps=1e-5
 | 
			
		||||
    ) -> LayerNormLinear:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.elementwise_affine = elementwise_affine
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        if elementwise_affine:
 | 
			
		||||
            self.weight = nn.Parameter(torch.ones(hidden_size))
 | 
			
		||||
        else:
 | 
			
		||||
            self.register_parameter("weight", None)
 | 
			
		||||
        self.register_parameter("bias", None)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        s = f"{self.__class__.__name__}({self.hidden_size}"
 | 
			
		||||
        if not self.elementwise_affine:
 | 
			
		||||
            s += f", elementwise_affine={self.elementwise_affine}"
 | 
			
		||||
        s += f", eps={self.eps}"
 | 
			
		||||
        s += ")"
 | 
			
		||||
        return s
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
 | 
			
		||||
        return layer_norm_linear_fn(
 | 
			
		||||
            x,
 | 
			
		||||
            self.weight,
 | 
			
		||||
            self.bias,
 | 
			
		||||
            weight,
 | 
			
		||||
            bias,
 | 
			
		||||
            residual=residual,
 | 
			
		||||
            eps=self.eps,
 | 
			
		||||
            prenorm=prenorm,
 | 
			
		||||
            residual_in_fp32=residual_in_fp32,
 | 
			
		||||
            is_rms_norm=False
 | 
			
		||||
        )
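
# Added note (not in the vendored file): unlike a stock projection layer, the
# weight/bias of the following linear are passed at call time rather than owned
# by the module, so the norm can be fused with an externally managed projection.
def _fused_norm_then_project(norm: "LayerNormLinear", proj: nn.Linear, x):
    return norm(x, proj.weight, proj.bias)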
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RMSNormLinear(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size,
 | 
			
		||||
        elementwise_affine: bool = True,
 | 
			
		||||
        eps=1e-5
 | 
			
		||||
    ) -> RMSNormLinear:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.elementwise_affine = elementwise_affine
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
 | 
			
		||||
        if elementwise_affine:
 | 
			
		||||
            self.weight = nn.Parameter(torch.ones(hidden_size))
 | 
			
		||||
        else:
 | 
			
		||||
            self.register_parameter("weight", None)
 | 
			
		||||
        self.register_parameter("bias", None)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        s = f"{self.__class__.__name__}({self.hidden_size}"
 | 
			
		||||
        if not self.elementwise_affine:
 | 
			
		||||
            s += f", elementwise_affine={self.elementwise_affine}"
 | 
			
		||||
        s += f", eps={self.eps}"
 | 
			
		||||
        s += ")"
 | 
			
		||||
        return s
 | 
			
		||||
 | 
			
		||||
    def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
 | 
			
		||||
        return layer_norm_linear_fn(
 | 
			
		||||
            x,
 | 
			
		||||
            self.weight,
 | 
			
		||||
            self.bias,
 | 
			
		||||
            weight,
 | 
			
		||||
            bias,
 | 
			
		||||
            residual=residual,
 | 
			
		||||
            eps=self.eps,
 | 
			
		||||
            prenorm=prenorm,
 | 
			
		||||
            residual_in_fp32=residual_in_fp32,
 | 
			
		||||
            is_rms_norm=True
 | 
			
		||||
        )
310 finetune/lora/v6/fla/modules/rotary.py vendored Normal file
@@ -0,0 +1,310 @@
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# Copyright (c) 2023, Tri Dao.
 | 
			
		||||
 | 
			
		||||
from typing import Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from einops import rearrange, repeat
 | 
			
		||||
 | 
			
		||||
from fla.ops.rotary import apply_rotary
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rotate_half(x, interleaved=False):
 | 
			
		||||
    if not interleaved:
 | 
			
		||||
        x1, x2 = x.chunk(2, dim=-1)
 | 
			
		||||
        return torch.cat((-x2, x1), dim=-1)
 | 
			
		||||
    else:
 | 
			
		||||
        x1, x2 = x[..., ::2], x[..., 1::2]
 | 
			
		||||
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
 | 
			
		||||
    """
 | 
			
		||||
    x: (batch_size, seqlen, nheads, headdim)
 | 
			
		||||
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
 | 
			
		||||
    """
 | 
			
		||||
    ro_dim = cos.shape[-1] * 2
 | 
			
		||||
    assert ro_dim <= x.shape[-1]
 | 
			
		||||
    cos = repeat(
 | 
			
		||||
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
 | 
			
		||||
    sin = repeat(
 | 
			
		||||
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
 | 
			
		||||
    return torch.cat(
 | 
			
		||||
        [x[..., :ro_dim] * cos +
 | 
			
		||||
            rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
 | 
			
		||||
        dim=-1,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ApplyRotaryEmb(torch.autograd.Function):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def forward(
 | 
			
		||||
        ctx,
 | 
			
		||||
        x,
 | 
			
		||||
        cos,
 | 
			
		||||
        sin,
 | 
			
		||||
        interleaved=False,
 | 
			
		||||
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        out = apply_rotary(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        if isinstance(seqlen_offsets, int):
            # Can't save int with save_for_backward
            ctx.save_for_backward(cos, sin, cu_seqlens)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        return dx, None, None, None, None, None, None, None


def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    return ApplyRotaryEmb.apply(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )
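

# Illustrative usage sketch: a minimal example of calling apply_rotary_emb
# defined above, assuming a CUDA device (the underlying Triton kernel operates
# on GPU tensors). Shapes follow the docstring: x is (batch, seqlen, nheads,
# headdim), cos / sin are (seqlen, rotary_dim / 2); the helper name is ours.
def _example_apply_rotary_emb():
    import torch

    batch, seqlen, nheads, headdim = 2, 128, 4, 64
    rotary_dim = 64  # must be <= headdim
    device, dtype = "cuda", torch.float16

    x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, device=device).float() / rotary_dim))
    t = torch.arange(seqlen, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)  # (seqlen, rotary_dim / 2)
    cos, sin = freqs.cos().to(dtype), freqs.sin().to(dtype)

    out = apply_rotary_emb(x, cos, sin, interleaved=False)  # same shape as x
    return out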


# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we constructed
            the position indices, we used the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices would also be in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device,
             dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device,
                                 dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype,
                                 device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(
                    device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in q and k is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding to q and k, returning the rotated tensors.
        """
        seqlen = q.shape[1]
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=q.device, dtype=q.dtype)
        if self.scale is None:
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )
            k = apply_rotary_emb_func(
                k,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )

        else:
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )
            k = apply_rotary_emb_func(
                k,
                self._cos_k_cached,
                self._sin_k_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )

        return q, k
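

# Illustrative usage sketch: driving RotaryEmbedding above through a prefill
# step and one cached decoding step, assuming a CUDA device and fp16 tensors
# (the Triton rotary kernel expects GPU tensors). The helper name is ours;
# q and k follow the forward() contract (batch, seqlen, nheads, headdim).
def _example_rotary_embedding():
    import torch

    batch, seqlen, nheads, headdim = 2, 16, 4, 64
    rotary = RotaryEmbedding(dim=headdim, interleaved=False, device="cuda")

    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    # Prefill: positions 0 .. seqlen - 1
    q, k = rotary(q, k, seqlen_offset=0)

    # One decoding step: a single new token at position seqlen (KV cache case)
    q1 = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
    k1 = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
    q1, k1 = rotary(q1, k1, seqlen_offset=seqlen)
    return q, k, q1, k1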
							
								
								
									
18  finetune/lora/v6/fla/ops/__init__.py  vendored  Normal file
@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-

from .based import fused_chunk_based, parallel_based
from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
from .retention import (chunk_retention, fused_chunk_retention,
                        fused_recurrent_retention, parallel_retention)

__all__ = [
    'fused_chunk_based',
    'parallel_based',
    'chunk_gla',
    'fused_chunk_gla',
    'fused_recurrent_gla',
    'chunk_retention',
    'fused_chunk_retention',
    'fused_recurrent_retention',
    'parallel_retention'
]
							
								
								
									
11  finetune/lora/v6/fla/ops/abc/__init__.py  vendored  Normal file
@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_abc
from .chunk_gate import chunk_gated_abc
from .recurrent_fuse import fused_recurrent_gated_abc

__all__ = [
    'chunk_abc',
    'chunk_gated_abc',
    'fused_recurrent_gated_abc'
]
							
								
								
									
1194  finetune/lora/v6/fla/ops/abc/chunk.py  vendored  Normal file
      File diff suppressed because it is too large
1287  finetune/lora/v6/fla/ops/abc/chunk_gate.py  vendored  Normal file
      File diff suppressed because it is too large
90  finetune/lora/v6/fla/ops/abc/naive.py  vendored  Normal file
@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-

from typing import Optional

import torch


def naive_recurrent_abc(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[int] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> torch.Tensor:
    dtype = q.dtype

    # [batch_size, n_heads, seq_len, n_slots]
    if g is None:
        z = s.float().logcumsumexp(2)
        g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
        s = torch.exp(s - z)
    q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
    B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]

    hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
    ok = torch.zeros_like(s)

    if scale is None:
        scale = q.shape[-1] ** -0.5

    final_state = None
    if initial_state is not None:
        hk += initial_state[0]

    for i in range(T):
        q_i = q[:, :, i] * scale
        k_i = k[:, :, i]
        v_i = s[:, :, i]
        g_i = g[:, :, i].exp()
        hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
        ok[:, :, i] = (q_i[..., None] * hk).sum(-2)

    qv = ok.softmax(-1)
    hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
    ov = torch.zeros_like(v)
    if initial_state is not None:
        hv += initial_state[1]

    for i in range(T):
        q_i = qv[:, :, i]
        k_i = s[:, :, i]
        v_i = v[:, :, i]
        g_i = g[:, :, i].exp()
        hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
        ov[:, :, i] = (q_i[..., None] * hv).sum(-2)

    if output_final_state:
        final_state = (hk, hv)
    return ov.to(dtype), final_state


def naive_cumsum_abc(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor
) -> torch.Tensor:
    """
    A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper.
    This is just for demonstration purposes, with no numerical stability guaranteed.
    """

    dtype = q.dtype
    q, k, v, s = map(lambda x: x.float(), (q, k, v, s))

    scale = q.shape[-1] ** -0.5
    # [batch_size, n_heads, seq_len, n_slots]
    s = (s - s.max(2, True)[0]).exp()
    z = s.cumsum(2)
    # [batch_size, n_heads, seq_len, n_slots, d_head]
    K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
    V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
    # [batch_size, n_heads, seq_len, n_slots]
    p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1)
    # [batch_size, n_heads, seq_len, d_head]
    o = torch.einsum('...m,...md->...d', p, V)
    return o.to(dtype), None
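

# Illustrative usage sketch: exercising the naive reference functions above with
# random tensors, assuming shapes (batch, heads, seq_len, d_head) for q/k/v and
# (batch, heads, seq_len, n_slots) for the slot scores s. The helper name is ours;
# naive_cumsum_abc is the paper-style formulation, naive_recurrent_abc the
# recurrent one.
def _example_naive_abc():
    B, H, T, D, M = 2, 4, 32, 64, 16
    q = torch.randn(B, H, T, D)
    k = torch.randn(B, H, T, D)
    v = torch.randn(B, H, T, D)
    s = torch.randn(B, H, T, M)
    o_rec, _ = naive_recurrent_abc(q, k, v, s)  # (B, H, T, D)
    o_cum, _ = naive_cumsum_abc(q, k, v, s)     # (B, H, T, D)
    return o_rec, o_cum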
							
								
								
									
388  finetune/lora/v6/fla/ops/abc/recurrent_fuse.py  vendored  Normal file
@ -0,0 +1,388 @@
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# Copyright (c) 2024, Yu Zhang, Songlin Yang
 | 
			
		||||
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
from torch.cuda.amp import custom_bwd, custom_fwd
 | 
			
		||||
 | 
			
		||||
from fla.utils import contiguous
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_recurrent_gated_abc_fwd_kernel(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    gk,
 | 
			
		||||
    gv,
 | 
			
		||||
    o,
 | 
			
		||||
    h0,
 | 
			
		||||
    ht,
 | 
			
		||||
    s_k_h,
 | 
			
		||||
    s_v_h,
 | 
			
		||||
    scale,
 | 
			
		||||
    B: tl.constexpr,
 | 
			
		||||
    H: tl.constexpr,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    V: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr,
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,
 | 
			
		||||
    STORE_FINAL_STATE: tl.constexpr,
 | 
			
		||||
    REVERSE: tl.constexpr,
 | 
			
		||||
    USE_GK: tl.constexpr,
 | 
			
		||||
    USE_GV: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    # indices
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
 | 
			
		||||
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
 | 
			
		||||
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
 | 
			
		||||
    p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
 | 
			
		||||
 | 
			
		||||
    if USE_GK:
 | 
			
		||||
        p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
 | 
			
		||||
    if USE_GV:
 | 
			
		||||
        p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
 | 
			
		||||
 | 
			
		||||
    mask_bk = (i_k * BK + tl.arange(0, BK)) < K
 | 
			
		||||
    mask_bv = (i_v * BV + tl.arange(0, BV)) < V
 | 
			
		||||
 | 
			
		||||
    h = tl.zeros([BV, BK], dtype=tl.float32)
 | 
			
		||||
    mask_kv = mask_bk[None, :] & mask_bv[:, None]
 | 
			
		||||
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
 | 
			
		||||
        h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
    for _ in range(0, T):
 | 
			
		||||
        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
 | 
			
		||||
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
            h = h * b_gk[None, :]
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
            h = h * b_gv[:, None]
 | 
			
		||||
        h += b_k[None, :] * b_v[:, None]
 | 
			
		||||
        b_o = h * b_q[None, :]
 | 
			
		||||
        b_o = tl.sum(b_o, axis=1)
 | 
			
		||||
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
 | 
			
		||||
        p_q += -K if REVERSE else K
 | 
			
		||||
        p_k += -K if REVERSE else K
 | 
			
		||||
        p_o += -V if REVERSE else V
 | 
			
		||||
        p_v += -V if REVERSE else V
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            p_gk += -K if REVERSE else K
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            p_gv += -V if REVERSE else V
 | 
			
		||||
 | 
			
		||||
    if STORE_FINAL_STATE:
 | 
			
		||||
        p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
 | 
			
		||||
        tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
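
# Reference sketch (ours, for exposition only; not used by the kernels): the
# per-head recurrence that fused_recurrent_gated_abc_fwd_kernel above evaluates
# blockwise. The hidden state h is (V, K); each step optionally decays it with
# the (already exponentiated) key/value gates, adds the outer product v_t k_t^T,
# and reads it out with the scaled query.
def _reference_gated_recurrence(q, k, v, gk=None, gv=None, scale=None):
    # q, k: (T, K); v: (T, V); gk: (T, K) or None; gv: (T, V) or None
    import torch
    T, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    h = torch.zeros(V, K, dtype=torch.float32, device=q.device)
    o = torch.zeros(T, V, dtype=torch.float32, device=q.device)
    for t in range(T):
        if gk is not None:
            h = h * gk[t].float()[None, :]
        if gv is not None:
            h = h * gv[t].float()[:, None]
        h = h + v[t].float()[:, None] * k[t].float()[None, :]
        o[t] = h @ (q[t].float() * scale)
    return o
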
@triton.jit
 | 
			
		||||
def fused_recurrent_gated_abc_bwd_kernel(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    gk,
 | 
			
		||||
    gv,
 | 
			
		||||
    do,
 | 
			
		||||
    dq,
 | 
			
		||||
    dk,
 | 
			
		||||
    dv,
 | 
			
		||||
    h0,
 | 
			
		||||
    s_k_h,
 | 
			
		||||
    s_v_h,
 | 
			
		||||
    scale,
 | 
			
		||||
    B: tl.constexpr,
 | 
			
		||||
    H: tl.constexpr,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    V: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr,
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,
 | 
			
		||||
    REVERSE: tl.constexpr,
 | 
			
		||||
    USE_GK: tl.constexpr,
 | 
			
		||||
    USE_GV: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
 | 
			
		||||
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
 | 
			
		||||
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
 | 
			
		||||
    p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
 | 
			
		||||
    p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
 | 
			
		||||
    if USE_GK:
 | 
			
		||||
        p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
 | 
			
		||||
    if USE_GV:
 | 
			
		||||
        p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
 | 
			
		||||
    mask_bk = i_k * BK + tl.arange(0, BK) < K
 | 
			
		||||
    mask_bv = i_v * BV + tl.arange(0, BV) < V
 | 
			
		||||
    mask_kv = mask_bk[:, None] & mask_bv[None, :]
 | 
			
		||||
    h = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
 | 
			
		||||
        h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
    for _ in range(0, T):
 | 
			
		||||
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
            h = h * b_gk[:, None]
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
            h = h * b_gv[None, :]
 | 
			
		||||
        h += b_k[:, None] * b_v[None, :]
 | 
			
		||||
        b_dq = tl.sum(h * b_do[None, :], axis=1) * scale
 | 
			
		||||
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)
 | 
			
		||||
 | 
			
		||||
        p_k += -K if REVERSE else K
 | 
			
		||||
        p_v += -V if REVERSE else V
 | 
			
		||||
        p_q += -K if REVERSE else K
 | 
			
		||||
        p_do += -V if REVERSE else V
 | 
			
		||||
        p_dq += -K if REVERSE else K
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            p_gk += -K if REVERSE else K
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            p_gv += -V if REVERSE else V
 | 
			
		||||
 | 
			
		||||
    # sync threads
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
 | 
			
		||||
    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
 | 
			
		||||
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
 | 
			
		||||
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
 | 
			
		||||
    p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
 | 
			
		||||
    p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
 | 
			
		||||
    p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
 | 
			
		||||
    if USE_GK:
 | 
			
		||||
        p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
 | 
			
		||||
    if USE_GV:
 | 
			
		||||
        p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
 | 
			
		||||
 | 
			
		||||
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
    for _ in range(T):
 | 
			
		||||
        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
 | 
			
		||||
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        b_dh += b_q[:, None] * b_do[None, :]
 | 
			
		||||
        b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
 | 
			
		||||
        b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
            b_dh *= b_gk[:, None]
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
            b_dh *= b_gv[None, :]
 | 
			
		||||
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
 | 
			
		||||
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
 | 
			
		||||
 | 
			
		||||
        p_q += K if REVERSE else -K
 | 
			
		||||
        p_k += K if REVERSE else -K
 | 
			
		||||
        p_v += V if REVERSE else -V
 | 
			
		||||
        p_do += V if REVERSE else -V
 | 
			
		||||
        p_dk += K if REVERSE else -K
 | 
			
		||||
        p_dv += V if REVERSE else -V
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            p_gk += K if REVERSE else -K
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            p_gv += V if REVERSE else -V
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedRecurrentGatedABCFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_fwd
 | 
			
		||||
    def forward(ctx, q, k, v, s, g, scale=None, initial_state=None, output_final_state=False, reverse=False):
 | 
			
		||||
        B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
 | 
			
		||||
        # default scale
 | 
			
		||||
        if scale is None:
 | 
			
		||||
            scale = K ** -0.5
 | 
			
		||||
 | 
			
		||||
        BK, BV, BM = min(K, 32), min(V, 32), min(M, 32)
 | 
			
		||||
        NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
 | 
			
		||||
        num_stages = 1
 | 
			
		||||
        num_warps = 1
 | 
			
		||||
 | 
			
		||||
        g = g.float().exp()
 | 
			
		||||
 | 
			
		||||
        final_state = (None, None)
 | 
			
		||||
        if output_final_state:
 | 
			
		||||
            final_state = (q.new_empty(B, H, K, M), q.new_empty(B, H, M, V))
 | 
			
		||||
 | 
			
		||||
        ok = q.new_empty(NK, B, H, T, M, dtype=torch.float)
 | 
			
		||||
        gk, gv = None, g
 | 
			
		||||
        grid = (NM, NK, B * H)
 | 
			
		||||
        fused_recurrent_gated_abc_fwd_kernel[grid](
 | 
			
		||||
            q, k, s, gk, gv, ok, initial_state[0], final_state[0],
 | 
			
		||||
            k.stride(1),
 | 
			
		||||
            s.stride(1),
 | 
			
		||||
            scale=scale,
 | 
			
		||||
            B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,
 | 
			
		||||
            USE_INITIAL_STATE=initial_state[0] is not None,
 | 
			
		||||
            STORE_FINAL_STATE=final_state[0] is not None,
 | 
			
		||||
            USE_GK=False,
 | 
			
		||||
            USE_GV=True,
 | 
			
		||||
            REVERSE=reverse,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages
 | 
			
		||||
        )
 | 
			
		||||
        ok = ok.sum(0)
 | 
			
		||||
 | 
			
		||||
        qv = ok.softmax(-1, dtype=torch.float)
 | 
			
		||||
        ov = q.new_empty(NM, B, H, T, V, dtype=torch.float)
 | 
			
		||||
        gk, gv = g, None
 | 
			
		||||
        grid = (NV, NM, B * H)
 | 
			
		||||
        fused_recurrent_gated_abc_fwd_kernel[grid](
 | 
			
		||||
            qv, s, v, gk, gv, ov, initial_state[1], final_state[1],
 | 
			
		||||
            s.stride(1),
 | 
			
		||||
            v.stride(1),
 | 
			
		||||
            scale=1.,
 | 
			
		||||
            B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,
 | 
			
		||||
            USE_INITIAL_STATE=initial_state[0] is not None,
 | 
			
		||||
            STORE_FINAL_STATE=final_state[0] is not None,
 | 
			
		||||
            USE_GK=True,
 | 
			
		||||
            USE_GV=False,
 | 
			
		||||
            REVERSE=reverse,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages
 | 
			
		||||
        )
 | 
			
		||||
        ov = ov.sum(0)
 | 
			
		||||
 | 
			
		||||
        ctx.save_for_backward(q, k, v, s, g, qv, *initial_state, ok)
 | 
			
		||||
        ctx.scale = scale
 | 
			
		||||
        ctx.reverse = reverse
 | 
			
		||||
        # we do not need the gradient of the final state from the next chunk
 | 
			
		||||
        # similar to truncated BPTT
 | 
			
		||||
        if final_state is not None:
 | 
			
		||||
            final_state = tuple(i.detach() for i in final_state)
 | 
			
		||||
        return ov.to(q.dtype), final_state
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_bwd
 | 
			
		||||
    def backward(ctx, do, dht=None):
 | 
			
		||||
        q, k, v, s, g, qv, *initial_state, ok = ctx.saved_tensors
 | 
			
		||||
        B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
 | 
			
		||||
        V = v.shape[-1]
 | 
			
		||||
        scale = ctx.scale
 | 
			
		||||
 | 
			
		||||
        BK, BV, BM = min(K, 32), min(V, 32), min(M, 32)
 | 
			
		||||
        NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
 | 
			
		||||
        num_stages = 1
 | 
			
		||||
        num_warps = 1
 | 
			
		||||
 | 
			
		||||
        dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
 | 
			
		||||
        dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
 | 
			
		||||
        dv = q.new_empty(NM, B, H, T, V, dtype=torch.float)
 | 
			
		||||
        gk, gv = g, None
 | 
			
		||||
        grid = (NV, NM, B * H)
 | 
			
		||||
        fused_recurrent_gated_abc_bwd_kernel[grid](
 | 
			
		||||
            qv, s, v, gk, gv, do, dqv, dsv, dv, initial_state[1],
 | 
			
		||||
            s.stride(1),
 | 
			
		||||
            v.stride(1),
 | 
			
		||||
            scale=1.,
 | 
			
		||||
            B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages,
 | 
			
		||||
            USE_INITIAL_STATE=initial_state[1] is not None,
 | 
			
		||||
            REVERSE=ctx.reverse,
 | 
			
		||||
            USE_GK=gk is not None,
 | 
			
		||||
            USE_GV=gv is not None
 | 
			
		||||
        )
 | 
			
		||||
        dqv = dqv.sum(0)
 | 
			
		||||
        dsv = dsv.sum(0)
 | 
			
		||||
        dv = dv.sum(0)
 | 
			
		||||
        dgk = dqv * qv.float() - dsv * s.float()
 | 
			
		||||
        dgk_cumsum = dgk.cumsum(-2)
 | 
			
		||||
        dgk = dgk + dgk_cumsum[:, :, -1, None] - dgk_cumsum
 | 
			
		||||
 | 
			
		||||
        dok = qv * (dqv - (qv * dqv).sum(-1, True))
 | 
			
		||||
        dq = q.new_empty(NM, B, H, T, K, dtype=torch.float)
 | 
			
		||||
        dk = q.new_empty(NM, B, H, T, K, dtype=torch.float)
 | 
			
		||||
        dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float)
 | 
			
		||||
        gk, gv = None, g
 | 
			
		||||
        grid = (NM, NK, B * H)
 | 
			
		||||
        fused_recurrent_gated_abc_bwd_kernel[grid](
 | 
			
		||||
            q, k, s, gk, gv, dok, dq, dk, dsk, initial_state[0],
 | 
			
		||||
            q.stride(1),
 | 
			
		||||
            s.stride(1),
 | 
			
		||||
            scale=scale,
 | 
			
		||||
            B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages,
 | 
			
		||||
            USE_INITIAL_STATE=initial_state[0] is not None,
 | 
			
		||||
            REVERSE=ctx.reverse,
 | 
			
		||||
            USE_GK=gk is not None,
 | 
			
		||||
            USE_GV=gv is not None
 | 
			
		||||
        )
 | 
			
		||||
        dq = dq.sum(0)
 | 
			
		||||
        dk = dk.sum(0)
 | 
			
		||||
        dsk = dsk.sum(0)
 | 
			
		||||
 | 
			
		||||
        dgv = dok.float() * ok.float() - dsk * s.float()
 | 
			
		||||
        dgv_cumsum = dgv.cumsum(-2)
 | 
			
		||||
        dgv = dgv + dgv_cumsum[:, :, -1, None] - dgv_cumsum
 | 
			
		||||
 | 
			
		||||
        ds = dsk.add_(dsv)
 | 
			
		||||
        dg = dgk.add_(dgv)
 | 
			
		||||
 | 
			
		||||
        return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, None, None, None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fused_recurrent_gated_abc(
 | 
			
		||||
    q: torch.Tensor,
 | 
			
		||||
    k: torch.Tensor,
 | 
			
		||||
    v: torch.Tensor,
 | 
			
		||||
    s: torch.Tensor,
 | 
			
		||||
    g: Optional[torch.Tensor] = None,
 | 
			
		||||
    scale: Optional[int] = None,
 | 
			
		||||
    initial_state: Optional[Tuple[torch.Tensor]] = None,
 | 
			
		||||
    output_final_state: Optional[bool] = False
 | 
			
		||||
) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Args:
 | 
			
		||||
        q (torch.Tensor):
 | 
			
		||||
            queries of shape `(B, H, T, K)`
 | 
			
		||||
        k (torch.Tensor):
 | 
			
		||||
            keys of shape `(B, H, T, K)`
 | 
			
		||||
        v (torch.Tensor):
 | 
			
		||||
            values of shape `(B, H, T, V)`
 | 
			
		||||
        g (torch.Tensor):
 | 
			
		||||
            Forget gates of shape `(B, H, T, M)` applied to keys.
 | 
			
		||||
            If not provided, this function is equivalent to vanilla ABC.
 | 
			
		||||
        scale (Optional[int]):
 | 
			
		||||
            Scale factor for attention scores.
 | 
			
		||||
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
 | 
			
		||||
        initial_state (Optional[Tuple[torch.Tensor]]):
 | 
			
		||||
            Initial state tuple having tensors of shape `(B, H, K, V)`. Default: `None`.
 | 
			
		||||
        output_final_state (Optional[bool]):
 | 
			
		||||
            Whether to output the final state tuple, having tensors of shape `(B, H, K, V)`. Default: `False`.
 | 
			
		||||
    """
 | 
			
		||||
    if initial_state is not None:
 | 
			
		||||
        initial_state = tuple(i.detach() for i in initial_state)
 | 
			
		||||
    if g is None:
 | 
			
		||||
        # TODO: these 3 steps take a huge amount of time and ought to be optimized
 | 
			
		||||
        z = s.float().logcumsumexp(2)
 | 
			
		||||
        g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
 | 
			
		||||
        s = torch.exp(s - z).to(k.dtype)
 | 
			
		||||
    if scale is None:
 | 
			
		||||
        scale = q.shape[-1] ** -0.5
 | 
			
		||||
    ov, final_state = FusedRecurrentGatedABCFunction.apply(q, k, v, s, g, scale, initial_state, output_final_state)
 | 
			
		||||
    return ov, final_state
 | 
			
		||||
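
# Illustrative usage sketch: calling fused_recurrent_gated_abc above with random
# inputs, assuming a CUDA device with Triton available. Shapes follow the
# docstring: q/k are (B, H, T, K), v is (B, H, T, V), slot scores s are
# (B, H, T, M); with g=None the call reduces to vanilla ABC. The helper name and
# the explicit zero initial state are ours.
def _example_fused_recurrent_gated_abc():
    import torch
    B, H, T, K, V, M = 2, 4, 128, 64, 64, 16
    q = torch.randn(B, H, T, K, device="cuda")
    k = torch.randn(B, H, T, K, device="cuda")
    v = torch.randn(B, H, T, V, device="cuda")
    s = torch.randn(B, H, T, M, device="cuda")
    hk0 = torch.zeros(B, H, K, M, device="cuda")
    hv0 = torch.zeros(B, H, M, V, device="cuda")
    o, final_state = fused_recurrent_gated_abc(
        q, k, v, s, initial_state=(hk0, hv0), output_final_state=True
    )
    return o, final_state
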
							
								
								
									
9  finetune/lora/v6/fla/ops/based/__init__.py  vendored  Normal file
@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

from .chunk_fuse import fused_chunk_based
from .parallel import parallel_based

__all__ = [
    'fused_chunk_based',
    'parallel_based'
]
							
								
								
									
410  finetune/lora/v6/fla/ops/based/chunk_fuse.py  vendored  Normal file
@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
from torch.cuda.amp import custom_bwd, custom_fwd
 | 
			
		||||
 | 
			
		||||
from fla.utils import contiguous
 | 
			
		||||
 | 
			
		||||
# on-the-fly computation without materializing hidden states into HBM
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_chunk_based_fwd_kernel(
 | 
			
		||||
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
 | 
			
		||||
    q,  # query [B, H, L, D_head_K]
 | 
			
		||||
    k,  # key [B, H, L, D_head_K]
 | 
			
		||||
    v,  # value [B, H, L, D_head_V]
 | 
			
		||||
    o,  # output [B, H, L, D_head_V]
 | 
			
		||||
    z,  # normalizer [B, H, L, 1]
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
    s_vo_h,  # stride size: L * D_head_V
 | 
			
		||||
    s_vo_t,  # stride size: D_head_V
 | 
			
		||||
    s_vo_d,  # stride size: 1
 | 
			
		||||
    B,  # batch size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
    DV: tl.constexpr,  # D_head_V
 | 
			
		||||
):
 | 
			
		||||
    # indices
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    o_i = tl.arange(0, BT)
 | 
			
		||||
 | 
			
		||||
    # [BT, BT]
 | 
			
		||||
    m_s = o_i[:, None] >= o_i[None, :]
 | 
			
		||||
 | 
			
		||||
    # [BV], zero-order taylor expansion
 | 
			
		||||
    b_h_0o = tl.zeros([BV], dtype=tl.float32)
 | 
			
		||||
    # [BK, BV], first-order taylor expansion
 | 
			
		||||
    b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
    # [BK, BK, BV] second-order taylor expansion
 | 
			
		||||
    b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    # make block pointers
 | 
			
		||||
    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
 | 
			
		||||
                            (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
 | 
			
		||||
                            (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
 | 
			
		||||
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
 | 
			
		||||
                            (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
    p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV),
 | 
			
		||||
                            (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
 | 
			
		||||
    p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)
 | 
			
		||||
    k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
 | 
			
		||||
    k_1o = tl.zeros([1, BK], dtype=tl.float32)
 | 
			
		||||
    k_0o = 0
 | 
			
		||||
 | 
			
		||||
    for i in range(0, tl.cdiv(T, BT)):
 | 
			
		||||
        # [BK, BT]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [BK*BK, BT]
 | 
			
		||||
        b_k_2o = b_k[:, None, :] * b_k[None, :, :]
 | 
			
		||||
        b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, BK]
 | 
			
		||||
        b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)
 | 
			
		||||
        b_o = tl.zeros([BT, BV], dtype=tl.float32)
 | 
			
		||||
        b_z = tl.zeros([BT], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
        # interchunk
 | 
			
		||||
        # zero-order
 | 
			
		||||
        b_o += b_h_0o
 | 
			
		||||
        b_z += k_0o
 | 
			
		||||
        # first-order
 | 
			
		||||
        b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)
 | 
			
		||||
        b_z += tl.sum(b_q * k_1o, axis=1)
 | 
			
		||||
        # second-order
 | 
			
		||||
        b_q_2o = b_q[:, :, None] * b_q[:, None, :]
 | 
			
		||||
        b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)
 | 
			
		||||
        b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5
 | 
			
		||||
        b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5
 | 
			
		||||
 | 
			
		||||
        # update running statistics
 | 
			
		||||
        k_1o += tl.sum(b_k, axis=1)[None, :]
 | 
			
		||||
        k_2o += tl.sum(b_k_2o, axis=1)[None, :]
 | 
			
		||||
        k_0o += BT
 | 
			
		||||
 | 
			
		||||
        # intrachunk
 | 
			
		||||
        # [BT, BT]
 | 
			
		||||
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
 | 
			
		||||
        b_s = 1 + b_s + 0.5 * b_s * b_s
 | 
			
		||||
        b_s = tl.where(m_s, b_s, 0)
 | 
			
		||||
        b_z += tl.sum(b_s, axis=1)
 | 
			
		||||
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
        tl.store(p_z, b_z.to(p_z.dtype.element_ty),
 | 
			
		||||
                 mask=(i * BT + tl.arange(0, BT)) < T)
 | 
			
		||||
 | 
			
		||||
        # update hidden state
 | 
			
		||||
        # [BK, BV]
 | 
			
		||||
        b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)
 | 
			
		||||
        b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)
 | 
			
		||||
        b_h_0o = b_h_0o + tl.sum(b_v, axis=0)
 | 
			
		||||
 | 
			
		||||
        p_q = tl.advance(p_q, (BT, 0))
 | 
			
		||||
        p_k = tl.advance(p_k, (0, BT))
 | 
			
		||||
        p_v = tl.advance(p_v, (BT, 0))
 | 
			
		||||
        p_o = tl.advance(p_o, (BT, 0))
 | 
			
		||||
        p_z += BT
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
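
# Reference sketch (ours, for exposition only; quadratic-time): the causal
# attention that the fused chunk kernels above compute in linear time. Based
# uses the 2nd-order Taylor feature map of exp, so each score is
# 1 + q·k + 0.5 * (q·k)^2, and z is the row-wise normalizer returned alongside o.
def _reference_based_attention(q, k, v, scale=1.0):
    # q, k: (B, H, T, DK); v: (B, H, T, DV)
    import torch
    s = torch.einsum('bhid,bhjd->bhij', (q * scale).float(), k.float())
    s = 1 + s + 0.5 * s * s
    T = q.shape[2]
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    s = s.masked_fill(~mask, 0)
    o = torch.einsum('bhij,bhjd->bhid', s, v.float())
    z = s.sum(-1)
    return o, z
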
# Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_chunk_based_bwd_kernel(
 | 
			
		||||
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
 | 
			
		||||
    # NV: number of split in the V dimension. NK: number of split in the K dimension
 | 
			
		||||
    q,  # query [B, H, L, D_head_K]
 | 
			
		||||
    k,  # key [B, H, L, D_head_K]
 | 
			
		||||
    v,  # value [B, H, L, D_head_V]
 | 
			
		||||
    do,  # gradient of output [B, H, L, D_head_V]
 | 
			
		||||
    dz,  # gradient of normalizer [B, H, L]
 | 
			
		||||
    dq,  # gradient of query [NV, B, H, L, D_head_K]
 | 
			
		||||
    dk,  # gradient of key [NV, B, H, L, D_head_K]
 | 
			
		||||
    dv,  # gradient of value [NK, B, H, L, D_head_V]
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
    s_vo_h,  # stride size: L * D_head_V
 | 
			
		||||
    s_vo_t,  # stride size: D_head_V
 | 
			
		||||
    s_vo_d,  # stride size: 1
 | 
			
		||||
    B,  # batch_size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
    DV: tl.constexpr,  # D_head_V
 | 
			
		||||
):
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    o_i = tl.arange(0, BT)
 | 
			
		||||
    m_s = o_i[:, None] >= o_i[None, :]
 | 
			
		||||
 | 
			
		||||
    # [BV], zero-order taylor expansion
 | 
			
		||||
    # b_h_0o = tl.zeros([BV], dtype=tl.float32)
 | 
			
		||||
    # [BK, BV], first-order taylor expansion
 | 
			
		||||
    b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)
 | 
			
		||||
    # [BK, BK, BV] second-order taylor expansion
 | 
			
		||||
    b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    k_1o = tl.zeros([1, BK], dtype=tl.float32)
 | 
			
		||||
    k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    for i in range(0, tl.cdiv(T, BT)):
 | 
			
		||||
        p_q = tl.make_block_ptr(
 | 
			
		||||
            q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_k = tl.make_block_ptr(
 | 
			
		||||
            k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_v = tl.make_block_ptr(
 | 
			
		||||
            v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
 | 
			
		||||
        p_do = tl.make_block_ptr(
 | 
			
		||||
            do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h,
 | 
			
		||||
                                 (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT
 | 
			
		||||
        b_dq = tl.zeros([BT, BK], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
        # load tensors
 | 
			
		||||
        # [BT, BK]
 | 
			
		||||
        b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
        b_q = (b_q * scale).to(b_q.dtype)
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [BV, BT]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
        # inter-chunk
 | 
			
		||||
        b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            b_dq += b_dz[:, None] * k_1o
 | 
			
		||||
        b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            b_dq_2o += (b_dz[:, None] * k_2o) * 0.5
 | 
			
		||||
        b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])
 | 
			
		||||
        b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)
 | 
			
		||||
        b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)
 | 
			
		||||
        b_dq *= scale
 | 
			
		||||
 | 
			
		||||
        # intra-chunk
 | 
			
		||||
        # [BT, BT]
 | 
			
		||||
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            b_ds += b_dz[:, None]
 | 
			
		||||
        b_ds = tl.where(m_s, b_ds, 0) * scale
 | 
			
		||||
        b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
 | 
			
		||||
        b_s = tl.where(m_s, b_s, 0)
 | 
			
		||||
        b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
        # store
 | 
			
		||||
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
        # update hidden state
 | 
			
		||||
        # [BT, BK*BK]
 | 
			
		||||
        b_k_2o = b_k[:, :, None] * b_k[:, None, :]
 | 
			
		||||
        b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
 | 
			
		||||
        # [BV, BK*BK]
 | 
			
		||||
        b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)
 | 
			
		||||
        # [BV, BK]
 | 
			
		||||
        b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            # update running statistics
 | 
			
		||||
            k_1o += tl.sum(b_k, axis=0)[None, :]
 | 
			
		||||
            k_2o += tl.sum(b_k_2o, axis=0)[None, :]
 | 
			
		||||
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
    b_h_1o = None
 | 
			
		||||
    b_h_2o = None
 | 
			
		||||
 | 
			
		||||
    # [BK, BV], first-order taylor expansion
 | 
			
		||||
    b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
    # [BK, BK, BV] second-order taylor expansion
 | 
			
		||||
    b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
 | 
			
		||||
    b_dh_0o = tl.zeros([BV], dtype=tl.float32)
 | 
			
		||||
    m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]
 | 
			
		||||
 | 
			
		||||
    dq_1o = tl.zeros([1, BK], dtype=tl.float32)
 | 
			
		||||
    dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):
 | 
			
		||||
        p_q = tl.make_block_ptr(
 | 
			
		||||
            q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1))
 | 
			
		||||
        p_k = tl.make_block_ptr(
 | 
			
		||||
            k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_v = tl.make_block_ptr(
 | 
			
		||||
            v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_do = tl.make_block_ptr(
 | 
			
		||||
            do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK),
 | 
			
		||||
                                 (s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV),
 | 
			
		||||
                                 (s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i
 | 
			
		||||
 | 
			
		||||
        b_dk = tl.zeros([BT, BK], dtype=tl.float32)
 | 
			
		||||
        b_dv = tl.zeros([BT, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
        b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
 | 
			
		||||
        b_q = (b_q * scale).to(b_k.dtype)
 | 
			
		||||
 | 
			
		||||
        # intra chunk
 | 
			
		||||
        b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            b_ds += b_dz[None, :]
 | 
			
		||||
        b_ds = tl.where(m_s, b_ds, 0)
 | 
			
		||||
        b_s = tl.dot(b_k, b_q, allow_tf32=False)
 | 
			
		||||
        b_s2 = 1 + b_s + 0.5 * b_s * b_s
 | 
			
		||||
        b_s = tl.where(m_s, b_s, 0)
 | 
			
		||||
        b_s2 = tl.where(m_s, b_s2, 0)
 | 
			
		||||
        b_ds *= (1+b_s)
 | 
			
		||||
 | 
			
		||||
        b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)
 | 
			
		||||
        b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
        # inter chunk
 | 
			
		||||
        b_k_2o = b_k[:, :, None] * b_k[:, None, :]
 | 
			
		||||
        b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
 | 
			
		||||
 | 
			
		||||
        b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
        b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
        b_dv += b_dh_0o
 | 
			
		||||
 | 
			
		||||
        b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            b_dk += dq_1o
 | 
			
		||||
 | 
			
		||||
        b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype),
 | 
			
		||||
                         tl.trans(b_v), allow_tf32=False)
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            b_dk_2o += dq_2o
 | 
			
		||||
        b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])
 | 
			
		||||
        b_k_fp32 = tl.trans(b_k.to(tl.float32))
 | 
			
		||||
        b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)
 | 
			
		||||
        b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)
 | 
			
		||||
        b_dk += tl.trans(b_dk2)
 | 
			
		||||
 | 
			
		||||
        # hidden state update
 | 
			
		||||
        b_dh_0o += tl.sum(b_do, axis=0)
 | 
			
		||||
        b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)
 | 
			
		||||
        b_q_2o = b_q[None, :, :] * b_q[:, None, :]
 | 
			
		||||
        b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)
 | 
			
		||||
        b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5
 | 
			
		||||
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]
 | 
			
		||||
            dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]
 | 
			
		||||
 | 
			
		||||
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedChunkBasedFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_fwd
 | 
			
		||||
    def forward(ctx, q, k, v, scale=1):
 | 
			
		||||
        batch_size, n_heads, seq_len, d_head_qk = q.shape
 | 
			
		||||
        # assert d_head_qk == 16, "currently we do not support feature dim other than 16"
 | 
			
		||||
        d_head_v = v.shape[-1]
 | 
			
		||||
 | 
			
		||||
        scale = scale
 | 
			
		||||
        BT = 16
 | 
			
		||||
        BK, BV = min(d_head_qk, 16), min(d_head_v, 32)
 | 
			
		||||
        BK, BV = max(BK, 16), max(BV, 16)
 | 
			
		||||
        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
 | 
			
		||||
 | 
			
		||||
        num_warps = 4
 | 
			
		||||
 | 
			
		||||
        # the norm of o might explode, so we need to use float32 here
 | 
			
		||||
        o = q.new_empty(NK, batch_size, n_heads, seq_len,
 | 
			
		||||
                        d_head_v, dtype=torch.float32)
 | 
			
		||||
        z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32)
 | 
			
		||||
 | 
			
		||||
        grid = (NV, NK, batch_size * n_heads)
 | 
			
		||||
        fused_chunk_based_fwd_kernel[grid](
 | 
			
		||||
            q, k, v, o, z,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
        )
 | 
			
		||||
        o = o.sum(0)
 | 
			
		||||
        z = z.sum(0)
 | 
			
		||||
        ctx.save_for_backward(q, k, v)
 | 
			
		||||
        ctx.scale = scale
 | 
			
		||||
        return o.to(q.dtype), z.to(z.dtype)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_bwd
 | 
			
		||||
    def backward(ctx, do, dz):
 | 
			
		||||
        q, k, v = ctx.saved_tensors
 | 
			
		||||
        batch_size, n_heads, seq_len, d_head_qk = q.shape
 | 
			
		||||
        d_head_v = v.shape[-1]
 | 
			
		||||
        scale = ctx.scale
 | 
			
		||||
 | 
			
		||||
        BT = 16
 | 
			
		||||
        BK, BV = min(d_head_qk, 16), min(d_head_v, 32)
 | 
			
		||||
        BK, BV = max(BK, 16), max(BV, 16)
 | 
			
		||||
        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
 | 
			
		||||
        num_stages = 1
 | 
			
		||||
        num_warps = 4
 | 
			
		||||
 | 
			
		||||
        dq = q.new_empty(NV, batch_size, n_heads,  seq_len, d_head_qk)
 | 
			
		||||
        dk = q.new_empty(NV, batch_size, n_heads,  seq_len, d_head_qk)
 | 
			
		||||
        dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
 | 
			
		||||
        grid = (NV, NK, batch_size * n_heads)
 | 
			
		||||
 | 
			
		||||
        fused_chunk_based_bwd_kernel[grid](
 | 
			
		||||
            q, k, v, do, dz, dq, dk, dv,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages
 | 
			
		||||
        )
 | 
			
		||||
        dq = dq.sum(0)
 | 
			
		||||
        dk = dk.sum(0)
 | 
			
		||||
        dv = dv.sum(0)
 | 
			
		||||
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
triton_fused_chunk_based = FusedChunkBasedFunction.apply
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fused_chunk_based(q, k, v, use_scale=True, use_normalize=True):
 | 
			
		||||
    assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'
 | 
			
		||||
    if use_scale:
 | 
			
		||||
        scale = q.shape[-1] ** -0.5
 | 
			
		||||
    else:
 | 
			
		||||
        scale = 1
 | 
			
		||||
    o, z = triton_fused_chunk_based(q, k, v, scale)
 | 
			
		||||
    if use_normalize:
 | 
			
		||||
        o = o / (z[..., None] + 1e-6)
 | 
			
		||||
    else:
 | 
			
		||||
        o = o
 | 
			
		||||
 | 
			
		||||
    return o.to(q.dtype)
 | 
			
		||||
							
								
								
									
										132
									
								
								finetune/lora/v6/fla/ops/based/naive.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										132
									
								
								finetune/lora/v6/fla/ops/based/naive.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,132 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
 | 
			
		||||
from fla.ops.based.chunk_fuse import fused_chunk_based
 | 
			
		||||
from fla.ops.based.parallel import parallel_based
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def naive_parallel_based(q, k, v, use_scale=True, use_norm=True):
 | 
			
		||||
    if use_scale:
 | 
			
		||||
        q = q * (q.shape[-1] ** -0.5)
 | 
			
		||||
    attn = q @ k.transpose(-2, -1)
 | 
			
		||||
    attn = 1 + attn + 1/2 * (attn ** 2)
 | 
			
		||||
    attn.masked_fill_(~torch.tril(torch.ones(
 | 
			
		||||
        q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
 | 
			
		||||
    o = attn @ v
 | 
			
		||||
    if use_norm:
 | 
			
		||||
        z = attn.sum(-1)
 | 
			
		||||
        return o / (z[..., None] + 1e-6)
 | 
			
		||||
    else:
 | 
			
		||||
        return o
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def naive_chunk_based(q, k, v, chunk_size=256):
 | 
			
		||||
    q = q * (q.shape[-1] ** -0.5)
 | 
			
		||||
 | 
			
		||||
    # compute normalizer.
 | 
			
		||||
    k_cumsum = torch.cumsum(k, dim=-2)
 | 
			
		||||
    kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
 | 
			
		||||
    # first
 | 
			
		||||
    z = (q * k_cumsum).sum(-1)
 | 
			
		||||
    # second order
 | 
			
		||||
    z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
 | 
			
		||||
    # zero-th order
 | 
			
		||||
    z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]
 | 
			
		||||
 | 
			
		||||
    # compute o
 | 
			
		||||
    # constant term
 | 
			
		||||
    _o = v.cumsum(-2)
 | 
			
		||||
 | 
			
		||||
    q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)
 | 
			
		||||
 | 
			
		||||
    k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
 | 
			
		||||
    v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
 | 
			
		||||
 | 
			
		||||
    intra_chunk_attn = q @ k.transpose(-2, -1)
 | 
			
		||||
    intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
 | 
			
		||||
    intra_chunk_attn.masked_fill_(
 | 
			
		||||
        ~torch.tril(
 | 
			
		||||
            torch.ones(chunk_size, chunk_size,
 | 
			
		||||
                       dtype=torch.bool, device=q.device),
 | 
			
		||||
        ), 0)
 | 
			
		||||
    o = intra_chunk_attn @ v
 | 
			
		||||
 | 
			
		||||
    # quadractic term
 | 
			
		||||
    kv = torch.einsum(
 | 
			
		||||
        'b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
 | 
			
		||||
    kv = kv.cumsum(2)
 | 
			
		||||
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
 | 
			
		||||
 | 
			
		||||
    o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)
 | 
			
		||||
 | 
			
		||||
    # linear term
 | 
			
		||||
    kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
 | 
			
		||||
    kv = kv.cumsum(2)
 | 
			
		||||
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
 | 
			
		||||
    o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)
 | 
			
		||||
 | 
			
		||||
    o = rearrange(o, 'b h n c d -> b h (n c) d')
 | 
			
		||||
    o = o + _o
 | 
			
		||||
    return o / (z[..., None] + 1e-6)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    B = 4
 | 
			
		||||
    H = 4
 | 
			
		||||
    L = 128
 | 
			
		||||
    # D = 15
 | 
			
		||||
    dtype = torch.float32
 | 
			
		||||
    q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
 | 
			
		||||
    k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
 | 
			
		||||
    v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True)
 | 
			
		||||
 | 
			
		||||
    do = torch.randn_like(v).cuda()
 | 
			
		||||
    ref = naive_parallel_based(q, k, v, True, True)
 | 
			
		||||
    ref.backward(do, retain_graph=True)
 | 
			
		||||
    ref_dq, q.grad = q.grad.clone(), None
 | 
			
		||||
    ref_dk, k.grad = k.grad.clone(), None
 | 
			
		||||
    ref_dv, v.grad = v.grad.clone(), None
 | 
			
		||||
 | 
			
		||||
    # tri = naive_chunk_based(q, k, v)
 | 
			
		||||
    # tri.backward(do, retain_graph=True)
 | 
			
		||||
    # tri_dq, q.grad = q.grad.clone(), None
 | 
			
		||||
    # tri_dk, k.grad = k.grad.clone(), None
 | 
			
		||||
    # tri_dv, v.grad = v.grad.clone(), None
 | 
			
		||||
 | 
			
		||||
    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
 | 
			
		||||
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
 | 
			
		||||
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
 | 
			
		||||
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
 | 
			
		||||
 | 
			
		||||
    tri = fused_chunk_based(q, k, v, True, True)
 | 
			
		||||
    tri.backward(do, retain_graph=True)
 | 
			
		||||
    tri_dq, q.grad = q.grad.clone(), None
 | 
			
		||||
    tri_dk, k.grad = k.grad.clone(), None
 | 
			
		||||
    tri_dv, v.grad = v.grad.clone(), None
 | 
			
		||||
    print((ref-tri).abs().max())
 | 
			
		||||
    print((ref_dq-tri_dq).abs().max())
 | 
			
		||||
    print((ref_dk-tri_dk).abs().max())
 | 
			
		||||
    print((ref_dv-tri_dv).abs().max())
 | 
			
		||||
 | 
			
		||||
    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
 | 
			
		||||
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
 | 
			
		||||
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
 | 
			
		||||
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
 | 
			
		||||
 | 
			
		||||
    tri = parallel_based(q, k, v, True, True)
 | 
			
		||||
    tri.backward(do, retain_graph=True)
 | 
			
		||||
    tri_dq, q.grad = q.grad.clone(), None
 | 
			
		||||
    tri_dk, k.grad = k.grad.clone(), None
 | 
			
		||||
    tri_dv, v.grad = v.grad.clone(), None
 | 
			
		||||
 | 
			
		||||
    print((ref-tri).abs().max())
 | 
			
		||||
    print((ref_dq-tri_dq).abs().max())
 | 
			
		||||
    print((ref_dk-tri_dk).abs().max())
 | 
			
		||||
    print((ref_dv-tri_dv).abs().max())
 | 
			
		||||
 | 
			
		||||
    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
 | 
			
		||||
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
 | 
			
		||||
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
 | 
			
		||||
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
 | 
			
		||||
							
								
								
									
										388
									
								
								finetune/lora/v6/fla/ops/based/parallel.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										388
									
								
								finetune/lora/v6/fla/ops/based/parallel.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,388 @@
 | 
			
		||||
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
from torch.cuda.amp import custom_bwd, custom_fwd
 | 
			
		||||
 | 
			
		||||
from fla.utils import contiguous
 | 
			
		||||
 | 
			
		||||
# Based: An Educational and Effective Sequence Mixer
 | 
			
		||||
# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def parallel_based_fwd_kernel(
 | 
			
		||||
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
 | 
			
		||||
    q,  # query [B, H, L, D_head_K]
 | 
			
		||||
    k,  # key [B, H, L, D_head_V]
 | 
			
		||||
    v,  # value [B, H, L, D_head_V]
 | 
			
		||||
    o,  # output [B, H, L, D_head_V]
 | 
			
		||||
    z,  # normalizer [B, H, L]
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
    s_vo_h,  # stride size: L * D_head_V
 | 
			
		||||
    s_vo_t,  # stride size: D_head_V
 | 
			
		||||
    s_vo_d,  # stride size: 1
 | 
			
		||||
    B,  # batch size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    BTL: tl.constexpr,  # BLOCK SIZE along the sequence dimension for Q
 | 
			
		||||
    BTS: tl.constexpr,  # BLOCK SIZE along the sequence dimension for K/V
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
    DV: tl.constexpr,  # D_head_V
 | 
			
		||||
):
 | 
			
		||||
    # i_c: chunk index. used for sequence parallelism
 | 
			
		||||
    i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    NV = tl.cdiv(DV, BV)
 | 
			
		||||
    i_k = i_kv // (NV)
 | 
			
		||||
    i_v = i_kv % (NV)
 | 
			
		||||
 | 
			
		||||
    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
 | 
			
		||||
                            (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
 | 
			
		||||
                            (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))
 | 
			
		||||
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
 | 
			
		||||
                            (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))
 | 
			
		||||
 | 
			
		||||
    # [BQ, BD] block Q, in the shared memory throughout the whole kernel
 | 
			
		||||
    b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
    b_q = (b_q * scale).to(b_q.dtype)
 | 
			
		||||
    b_o = tl.zeros([BTL, BV], dtype=tl.float32)
 | 
			
		||||
    b_z = tl.zeros([BTL], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    # Q block and K block have no overlap
 | 
			
		||||
    # no need for mask, thereby saving flops
 | 
			
		||||
    for _ in range(0, i_c * BTL, BTS):
 | 
			
		||||
        # [BK, BTS]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
        # [BTS, BV]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        # [BTL, BTS]
 | 
			
		||||
        b_s = tl.dot(b_q, (b_k), allow_tf32=False)
 | 
			
		||||
        b_s = 1 + b_s + 0.5 * b_s * b_s
 | 
			
		||||
        b_z += tl.sum(b_s, axis=1)
 | 
			
		||||
 | 
			
		||||
        # [BQ, BD]
 | 
			
		||||
        b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
 | 
			
		||||
        p_k = tl.advance(p_k, (0, BTS))
 | 
			
		||||
        p_v = tl.advance(p_v, (BTS, 0))
 | 
			
		||||
 | 
			
		||||
    # # rescale interchunk output
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
    o_q = tl.arange(0, BTL)
 | 
			
		||||
    # # sync threads, easy for compiler to optimize
 | 
			
		||||
    # tl.debug_barrier()
 | 
			
		||||
 | 
			
		||||
    o_k = tl.arange(0, BTS)
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
 | 
			
		||||
                            (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
 | 
			
		||||
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
 | 
			
		||||
                            (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
 | 
			
		||||
    # Q block and K block have overlap. masks required
 | 
			
		||||
    for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
 | 
			
		||||
        # [BK, BTS]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [BTS, BV]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        # [BTL, BTS]
 | 
			
		||||
        m_s = o_q[:, None] >= o_k[None, :]
 | 
			
		||||
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
 | 
			
		||||
        b_s = 1 + b_s + 0.5 * b_s * b_s
 | 
			
		||||
        b_s = tl.where(m_s, b_s, 0)
 | 
			
		||||
        b_z += tl.sum(b_s, axis=1)
 | 
			
		||||
        # [BTL, BV]
 | 
			
		||||
        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
        p_k = tl.advance(p_k, (0, BTS))
 | 
			
		||||
        p_v = tl.advance(p_v, (BTS, 0))
 | 
			
		||||
        o_k += BTS
 | 
			
		||||
 | 
			
		||||
    p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),
 | 
			
		||||
                            (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
 | 
			
		||||
    p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)
 | 
			
		||||
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
    tl.store(p_z, b_z.to(p_z.dtype.element_ty),
 | 
			
		||||
             mask=((i_c * BTL + tl.arange(0, BTL)) < T))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def _parallel_based_bwd_dq(
 | 
			
		||||
    i_bh, i_c, i_k, i_v, i_h,
 | 
			
		||||
    q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
 | 
			
		||||
    s_vo_t, s_vo_d, B, H, T, scale,
 | 
			
		||||
    BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
 | 
			
		||||
    DK: tl.constexpr,  DV: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
 | 
			
		||||
                             (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
 | 
			
		||||
    p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK),
 | 
			
		||||
                            (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
 | 
			
		||||
    b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
    b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
 | 
			
		||||
    b_q = (b_q * scale).to(b_q.dtype)
 | 
			
		||||
    b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
 | 
			
		||||
                            (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))
 | 
			
		||||
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
 | 
			
		||||
                            (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))
 | 
			
		||||
    p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)
 | 
			
		||||
    b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)
 | 
			
		||||
 | 
			
		||||
    for _ in range(0, i_c * BTL, BTS):
 | 
			
		||||
        # [BTS, BK]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [BV, BTS]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        # [BTL, BTS]
 | 
			
		||||
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            b_ds += b_dz[:, None]
 | 
			
		||||
        else:
 | 
			
		||||
            b_ds = b_ds
 | 
			
		||||
        b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
 | 
			
		||||
        # [BQ, BD]
 | 
			
		||||
        b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False)
 | 
			
		||||
        p_k = tl.advance(p_k, (BTS, 0))
 | 
			
		||||
        p_v = tl.advance(p_v, (0, BTS))
 | 
			
		||||
 | 
			
		||||
    b_dq *= scale
 | 
			
		||||
    o_q = tl.arange(0, BTL)
 | 
			
		||||
    o_k = tl.arange(0, BTS)
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
 | 
			
		||||
                            (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
 | 
			
		||||
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
 | 
			
		||||
                            (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
 | 
			
		||||
    # Q block and K block have overlap. masks required
 | 
			
		||||
    for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
 | 
			
		||||
        # [BTS, BK]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [BV, BTS]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        # [BTL, BTS]
 | 
			
		||||
        m_s = o_q[:, None] >= o_k[None, :]
 | 
			
		||||
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            b_ds += b_dz[:, None]
 | 
			
		||||
        else:
 | 
			
		||||
            b_ds = b_ds
 | 
			
		||||
        b_ds = tl.where(m_s, b_ds, 0) * scale
 | 
			
		||||
        b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
 | 
			
		||||
        b_s = tl.where(m_s, b_s, 0)
 | 
			
		||||
        # [BTL, BK]
 | 
			
		||||
        b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype),
 | 
			
		||||
                       b_k, allow_tf32=False)
 | 
			
		||||
        p_k = tl.advance(p_k, (BTS, 0))
 | 
			
		||||
        p_v = tl.advance(p_v, (0, BTS))
 | 
			
		||||
        o_k += BTS
 | 
			
		||||
    p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),
 | 
			
		||||
                             (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
 | 
			
		||||
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def _parallel_based_bwd_dkv(
 | 
			
		||||
    i_bh, i_c, i_k, i_v, i_h,
 | 
			
		||||
    q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
 | 
			
		||||
    s_vo_t, s_vo_d, B, H, T, scale,
 | 
			
		||||
    BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
 | 
			
		||||
    DK: tl.constexpr,  DV: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    # compute dk dv
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
 | 
			
		||||
                            (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
 | 
			
		||||
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
 | 
			
		||||
                            (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
 | 
			
		||||
    b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(
 | 
			
		||||
        p_v, boundary_check=(0, 1))
 | 
			
		||||
    b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
 | 
			
		||||
        [BTL, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
 | 
			
		||||
        p_q = tl.make_block_ptr(
 | 
			
		||||
            q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
 | 
			
		||||
        p_do = tl.make_block_ptr(
 | 
			
		||||
            do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
 | 
			
		||||
        p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))  # [BK, BTS]
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)  # [BV, BTS]
 | 
			
		||||
        b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
 | 
			
		||||
        b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \
 | 
			
		||||
            scale  # [BTL, BTS]
 | 
			
		||||
        b_s2 = 1 + b_s + 0.5 * b_s * b_s
 | 
			
		||||
        b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
 | 
			
		||||
        b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            b_ds += b_dz[None, :] * scale
 | 
			
		||||
        else:
 | 
			
		||||
            b_ds = b_ds
 | 
			
		||||
        b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype),
 | 
			
		||||
                       tl.trans(b_q), allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
    o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
 | 
			
		||||
    for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
 | 
			
		||||
        p_q = tl.make_block_ptr(
 | 
			
		||||
            q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
 | 
			
		||||
        p_do = tl.make_block_ptr(
 | 
			
		||||
            do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
 | 
			
		||||
        p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))  # [BD, BQ]
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
 | 
			
		||||
        b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
 | 
			
		||||
        # [BK, BQ]
 | 
			
		||||
        m_s = o_k[:, None] <= o_q[None, :]
 | 
			
		||||
        b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
 | 
			
		||||
        b_s2 = 1 + b_s + 0.5 * b_s * b_s
 | 
			
		||||
        b_s = tl.where(m_s, b_s, 0)
 | 
			
		||||
        b_s2 = tl.where(m_s, b_s2, 0)
 | 
			
		||||
 | 
			
		||||
        b_ds = tl.dot(b_v, b_do, allow_tf32=False)
 | 
			
		||||
        if i_v == 0:
 | 
			
		||||
            b_ds += b_dz[None, :]
 | 
			
		||||
        else:
 | 
			
		||||
            b_ds = b_ds
 | 
			
		||||
        b_ds = tl.where(m_s, b_ds, 0) * scale
 | 
			
		||||
        # [BK, BD]
 | 
			
		||||
        b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
 | 
			
		||||
        b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype),
 | 
			
		||||
                       tl.trans(b_q), allow_tf32=False)
 | 
			
		||||
        o_q += BTS
 | 
			
		||||
 | 
			
		||||
    p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,
 | 
			
		||||
                             (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
 | 
			
		||||
    p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,
 | 
			
		||||
                             (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
 | 
			
		||||
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def parallel_based_bwd_kernel(
 | 
			
		||||
    q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
 | 
			
		||||
    s_vo_t, s_vo_d, B, H, T, scale,
 | 
			
		||||
    BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
 | 
			
		||||
    DK: tl.constexpr,  DV: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    NV = tl.cdiv(DV, BV)
 | 
			
		||||
    i_k = i_kv // (NV)
 | 
			
		||||
    i_v = i_kv % (NV)
 | 
			
		||||
    i_h = i_bh % H
 | 
			
		||||
    _parallel_based_bwd_dq(
 | 
			
		||||
        i_bh, i_c, i_k, i_v, i_h,
 | 
			
		||||
        q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
 | 
			
		||||
        s_vo_t, s_vo_d, B, H, T, scale,  BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV
 | 
			
		||||
    )
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
    _parallel_based_bwd_dkv(
 | 
			
		||||
        i_bh, i_c, i_k, i_v, i_h,
 | 
			
		||||
        q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
 | 
			
		||||
        s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParallelBasedFunction(torch.autograd.Function):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_fwd
 | 
			
		||||
    def forward(ctx, q, k, v, scale):
 | 
			
		||||
        BTL, BTS = 128, 32
 | 
			
		||||
        assert BTL % BTS == 0
 | 
			
		||||
        # assert q.shape[-1] % 16 == 0
 | 
			
		||||
        BK = min(128, triton.next_power_of_2(k.shape[-1]))
 | 
			
		||||
        BV = min(128, triton.next_power_of_2(v.shape[-1]))
 | 
			
		||||
        BK, BV = max(BK, 16), max(BV, 16)
 | 
			
		||||
        batch_size, n_heads, seq_len, d_head_qk = q.shape
 | 
			
		||||
        d_head_v = v.shape[-1]
 | 
			
		||||
        num_stages = 2
 | 
			
		||||
        num_warps = 4
 | 
			
		||||
        NK = triton.cdiv(d_head_qk, BK)
 | 
			
		||||
        NV = triton.cdiv(d_head_v, BV)
 | 
			
		||||
        grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
 | 
			
		||||
 | 
			
		||||
        assert NK == 1, "will encounter some synchronization issue if not."
 | 
			
		||||
 | 
			
		||||
        o = torch.empty(NK, batch_size, n_heads, seq_len,
 | 
			
		||||
                        d_head_v, device=q.device)
 | 
			
		||||
        z = torch.empty(NK, batch_size, n_heads, seq_len,
 | 
			
		||||
                        device=q.device)
 | 
			
		||||
        parallel_based_fwd_kernel[grid](
 | 
			
		||||
            q, k, v, o, z,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages
 | 
			
		||||
        )
 | 
			
		||||
        ctx.save_for_backward(q, k, v)
 | 
			
		||||
        ctx.scale = scale
 | 
			
		||||
        return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @custom_bwd
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def backward(ctx, do, dz):
 | 
			
		||||
        q, k, v = ctx.saved_tensors
 | 
			
		||||
        scale = ctx.scale
 | 
			
		||||
        BTL, BTS = 64, 32
 | 
			
		||||
        assert BTL % BTS == 0
 | 
			
		||||
        BK = min(128, triton.next_power_of_2(k.shape[-1]))
 | 
			
		||||
        BV = min(128, triton.next_power_of_2(v.shape[-1]))
 | 
			
		||||
        BK, BV = max(BK, 16), max(BV, 16)
 | 
			
		||||
        batch_size, n_heads, seq_len, d_head_qk = q.shape
 | 
			
		||||
        d_head_v = v.shape[-1]
 | 
			
		||||
        num_stages = 2
 | 
			
		||||
        num_warps = 4
 | 
			
		||||
        NK = triton.cdiv(d_head_qk, BK)
 | 
			
		||||
        NV = triton.cdiv(d_head_v, BV)
 | 
			
		||||
        grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
 | 
			
		||||
 | 
			
		||||
        assert NK == 1, "will encounter some synchronization issue if not"
 | 
			
		||||
 | 
			
		||||
        dq = torch.empty(NV, batch_size, n_heads, seq_len,
 | 
			
		||||
                         d_head_qk, dtype=q.dtype, device=q.device)
 | 
			
		||||
        dk = torch.empty(NV, batch_size, n_heads, seq_len,
 | 
			
		||||
                         d_head_qk, dtype=q.dtype, device=q.device)
 | 
			
		||||
        dv = torch.empty(NK, batch_size, n_heads, seq_len,
 | 
			
		||||
                         d_head_v, dtype=q.dtype, device=q.device)
 | 
			
		||||
 | 
			
		||||
        parallel_based_bwd_kernel[grid](
 | 
			
		||||
            q, k, v, do, dz, dq, dk, dv,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
triton_parallel_based = ParallelBasedFunction.apply
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=False):
 | 
			
		||||
    assert q.shape[-1] <= 128, "only support feature dim up to 128"
 | 
			
		||||
    if use_scale:
 | 
			
		||||
        scale = q.shape[-1] ** -0.5
 | 
			
		||||
    else:
 | 
			
		||||
        scale = 1
 | 
			
		||||
    o, z = triton_parallel_based(q, k, v, scale)
 | 
			
		||||
    if return_both:
 | 
			
		||||
        return o, z
 | 
			
		||||
    if use_normalize:
 | 
			
		||||
        o = o / (z[..., None] + 1e-6)
 | 
			
		||||
    else:
 | 
			
		||||
        o = o
 | 
			
		||||
    return o.to(q.dtype)
 | 
			
		||||
							
								
								
									
										4
									
								
								finetune/lora/v6/fla/ops/delta_rule/README.md
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								finetune/lora/v6/fla/ops/delta_rule/README.md
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,4 @@
 | 
			
		||||
- Delta Rule
 | 
			
		||||
 | 
			
		||||
The implementation of delta rule described in https://arxiv.org/abs/2102.11174 
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										11
									
								
								finetune/lora/v6/fla/ops/delta_rule/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								finetune/lora/v6/fla/ops/delta_rule/__init__.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,11 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from .chunk_fuse import fused_chunk_delta_rule
 | 
			
		||||
from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule
 | 
			
		||||
from .chunk import chunk_delta_rule
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    'fused_chunk_delta_rule',
 | 
			
		||||
    'fused_recurrent_linear_attn_delta_rule',
 | 
			
		||||
    'chunk_delta_rule'
 | 
			
		||||
]
 | 
			
		||||
							
								
								
									
										544
									
								
								finetune/lora/v6/fla/ops/delta_rule/chunk.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										544
									
								
								finetune/lora/v6/fla/ops/delta_rule/chunk.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,544 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
from fla.ops.utils import contiguous
 | 
			
		||||
from torch.cuda.amp import custom_bwd, custom_fwd
 | 
			
		||||
from fla.ops.delta_rule.wy_fast import fwd_recompute_w_u, fwd_prepare_wy_repr, bwd_prepare_wy_repr
 | 
			
		||||
from fla.ops.delta_rule.chunk_fuse import fused_chunk_delta_rule_fwd, fused_chunk_delta_rule_bwd
 | 
			
		||||
# from fla.ops.delta_rule.utils import bwd_prepare_wy_repr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK", "BV"], 
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fwd_prepare_dv_kernel(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    do,
 | 
			
		||||
    dv,
 | 
			
		||||
    s_qk_h,
 | 
			
		||||
    s_qk_t,
 | 
			
		||||
    s_qk_d,
 | 
			
		||||
    s_vo_h,
 | 
			
		||||
    s_vo_t,
 | 
			
		||||
    s_vo_d,
 | 
			
		||||
    T,
 | 
			
		||||
    K,
 | 
			
		||||
    V,
 | 
			
		||||
    scale,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
 | 
			
		||||
    
 | 
			
		||||
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    for i_k in range(tl.cdiv(K, BK)):
 | 
			
		||||
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1)) 
 | 
			
		||||
        b_q = (b_q * scale).to(b_k.dtype)
 | 
			
		||||
        b_A += tl.dot(b_k, b_q, allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A , 0).to(do.dtype.element_ty)
 | 
			
		||||
 | 
			
		||||
    for i_v in range(tl.cdiv(V, BV)):
 | 
			
		||||
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1))
 | 
			
		||||
        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
 | 
			
		||||
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fwd_prepare_dv(q, k, do, BT):
 | 
			
		||||
    dv = torch.empty_like(do)
 | 
			
		||||
    B, H, T, K, V = *k.shape, do.shape[-1]
 | 
			
		||||
    NT = triton.cdiv(T, BT)
 | 
			
		||||
    BK = min(triton.next_power_of_2(K), 64)
 | 
			
		||||
    BV = min(triton.next_power_of_2(V), 64)
 | 
			
		||||
    fwd_prepare_dv_kernel[(NT, B*H)](
 | 
			
		||||
        q, k, do, dv,
 | 
			
		||||
        k.stride(1), k.stride(2), k.stride(3), 
 | 
			
		||||
        do.stride(1), do.stride(2), do.stride(3),
 | 
			
		||||
        T, K, V, K**-0.5, BT, BK, BV
 | 
			
		||||
    )
 | 
			
		||||
    return dv
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK", "BV"], 
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_delta_rule_fwd_kernel_h(
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    d, 
 | 
			
		||||
    v_new,
 | 
			
		||||
    h,
 | 
			
		||||
    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
 | 
			
		||||
    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
 | 
			
		||||
    s_qk_h,
 | 
			
		||||
    s_qk_t,
 | 
			
		||||
    s_qk_d,
 | 
			
		||||
    s_vo_h,
 | 
			
		||||
    s_vo_t,
 | 
			
		||||
    s_vo_d,
 | 
			
		||||
    s_h_h,
 | 
			
		||||
    s_h_t,
 | 
			
		||||
    H: tl.constexpr,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    V: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BC: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr,
 | 
			
		||||
    NT: tl.constexpr,
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,
 | 
			
		||||
    STORE_FINAL_STATE: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    # [BK, BV]
 | 
			
		||||
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
    for i_t in range(NT):
 | 
			
		||||
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
        b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
        # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
 | 
			
		||||
        for i_c in range(tl.cdiv(BT, BC)):
 | 
			
		||||
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
 | 
			
		||||
            p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
            p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
 | 
			
		||||
            p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))   
 | 
			
		||||
            b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
            # [BT, BK]
 | 
			
		||||
            b_d = tl.load(p_d, boundary_check=(0, 1))
 | 
			
		||||
            # [BT, BV]
 | 
			
		||||
            b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
            b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
            # [BK, BV]
 | 
			
		||||
            tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
            b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
        b_h += b_h_cumsum      
 | 
			
		||||
        
 | 
			
		||||
    if STORE_FINAL_STATE:
 | 
			
		||||
        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK", "BV"], 
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_linear_attn_fwd_kernel_o(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    h,
 | 
			
		||||
    o,
 | 
			
		||||
    s_qk_h,
 | 
			
		||||
    s_qk_t,
 | 
			
		||||
    s_qk_d,
 | 
			
		||||
    s_vo_h,
 | 
			
		||||
    s_vo_t,
 | 
			
		||||
    s_vo_d,
 | 
			
		||||
    s_h_h,
 | 
			
		||||
    s_h_t,
 | 
			
		||||
    scale,
 | 
			
		||||
    H: tl.constexpr,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    V: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    o_i = tl.arange(0, BT)
 | 
			
		||||
    m_s = o_i[:, None] >= o_i[None, :]
 | 
			
		||||
 | 
			
		||||
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
 | 
			
		||||
    b_s = tl.zeros([BT, BT], dtype=tl.float32)
 | 
			
		||||
    for i_k in range(tl.cdiv(K, BK)):
 | 
			
		||||
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
 | 
			
		||||
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        # [BT, BK]
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1)) 
 | 
			
		||||
        b_q = (b_q * scale).to(b_q.dtype)
 | 
			
		||||
        # [BK, BT]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [BK, BV]
 | 
			
		||||
        b_h = tl.load(p_h, boundary_check=(0, 1))
 | 
			
		||||
        b_o += tl.dot(b_q, b_h, allow_tf32=False)
 | 
			
		||||
        b_s += tl.dot(b_q, b_k, allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
    b_s = tl.where(m_s, b_s, 0)
 | 
			
		||||
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
    b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
    b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) 
 | 
			
		||||
    p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK", "BV"], 
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_delta_rule_bwd_kernel_dhu(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    d,
 | 
			
		||||
    do,
 | 
			
		||||
    dh,
 | 
			
		||||
    dv,
 | 
			
		||||
    dv2,
 | 
			
		||||
    s_qk_h,
 | 
			
		||||
    s_qk_t,
 | 
			
		||||
    s_qk_d,
 | 
			
		||||
    s_vo_h,
 | 
			
		||||
    s_vo_t,
 | 
			
		||||
    s_vo_d,
 | 
			
		||||
    s_h_h,
 | 
			
		||||
    s_h_t,
 | 
			
		||||
    scale,
 | 
			
		||||
    H: tl.constexpr,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    V: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BC: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr,
 | 
			
		||||
    NT: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    # [BK, BV]
 | 
			
		||||
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
    for i_t in range(NT - 1, -1, -1):
 | 
			
		||||
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
 | 
			
		||||
            p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
 | 
			
		||||
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
            p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
 | 
			
		||||
            p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
 | 
			
		||||
            p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
 | 
			
		||||
            # [BK, BT]
 | 
			
		||||
            b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
            b_q = (b_q * scale).to(b_q.dtype)
 | 
			
		||||
            # [BT, BK]
 | 
			
		||||
            b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
            b_d = tl.load(p_d, boundary_check=(0, 1))        
 | 
			
		||||
            # [BT, V]
 | 
			
		||||
            b_do = tl.load(p_do, boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
            # [BT, BT]
 | 
			
		||||
            # b_s = tl.dot(b_k, b_q, allow_tf32=False)
 | 
			
		||||
            # b_s = tl.where(m_s, b_s, 0)
 | 
			
		||||
            # b_dv = tl.dot(b_s.to(b_do.dtype), b_do, allow_tf32=False) + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
            b_dv = tl.load(p_dv, boundary_check=(0, 1))
 | 
			
		||||
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
            p_dv2 = tl.make_block_ptr(dv2 + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
 | 
			
		||||
            tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
            # [BK, BV]
 | 
			
		||||
            b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) 
 | 
			
		||||
            b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
 | 
			
		||||
        b_dh += b_dh_tmp
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK", "BV"], 
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_delta_rule_bwd_kernel_dqkw(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    w, 
 | 
			
		||||
    h,
 | 
			
		||||
    do,
 | 
			
		||||
    dh,
 | 
			
		||||
    dq,
 | 
			
		||||
    dk,
 | 
			
		||||
    dv,
 | 
			
		||||
    dw,
 | 
			
		||||
    s_qk_h,
 | 
			
		||||
    s_qk_t,
 | 
			
		||||
    s_qk_d,
 | 
			
		||||
    s_vo_h,
 | 
			
		||||
    s_vo_t,
 | 
			
		||||
    s_vo_d,
 | 
			
		||||
    s_h_h,
 | 
			
		||||
    s_h_t,
 | 
			
		||||
    scale,
 | 
			
		||||
    H: tl.constexpr,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    V: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr,
 | 
			
		||||
    NT: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    n_bh = tl.num_programs(2)
 | 
			
		||||
    o_i = tl.arange(0, BT)
 | 
			
		||||
    
 | 
			
		||||
    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
 | 
			
		||||
    b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
    b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
    b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
 | 
			
		||||
    b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)
 | 
			
		||||
 | 
			
		||||
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
 | 
			
		||||
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
 | 
			
		||||
    b_dw = tl.zeros([BT, BK], dtype=tl.float32)
 | 
			
		||||
    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
 | 
			
		||||
    for i_v in range(tl.cdiv(V, BV)):
 | 
			
		||||
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
 | 
			
		||||
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1))
 | 
			
		||||
        # [BV, BK]
 | 
			
		||||
        b_h = tl.load(p_h, boundary_check=(0, 1))
 | 
			
		||||
        # [BK, BV]
 | 
			
		||||
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, BT]
 | 
			
		||||
        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
 | 
			
		||||
        # [BT, BK]
 | 
			
		||||
        b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
 | 
			
		||||
        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
        b_dv = tl.load(p_dv, boundary_check=(0, 1))
 | 
			
		||||
        b_dw += tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
        
 | 
			
		||||
    # [BT, BT]
 | 
			
		||||
    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
 | 
			
		||||
    # [BT, BK]
 | 
			
		||||
    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
 | 
			
		||||
    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
 | 
			
		||||
 | 
			
		||||
    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
    tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
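# Materializes the chunkwise hidden states. A rough reading of the code below (hedged):
# given keys k, WY decay factors w and transformed values u, it launches
# chunk_delta_rule_fwd_kernel_h and returns h of shape [B, H, NT*K, V] (one [K, V]
# state block per chunk) together with v_new, the intra-chunk-corrected values.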
def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
 | 
			
		||||
    B, H, T, K, V = *k.shape, u.shape[-1]
 | 
			
		||||
 | 
			
		||||
    BK = triton.next_power_of_2(K)
 | 
			
		||||
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
 | 
			
		||||
    BV = 16 if BK > 128 else 32        
 | 
			
		||||
    BV = 64 if BK <= 64 else BV
 | 
			
		||||
    BC = 16 if BK > 128 else 32 
 | 
			
		||||
    BC = 64 if BK <= 64 else BC
 | 
			
		||||
    BC = min(BT, BC)
 | 
			
		||||
    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
 | 
			
		||||
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
 | 
			
		||||
 | 
			
		||||
    h = k.new_empty(B, H, NT * K, V)
 | 
			
		||||
    grid = (NK, NV, B * H)
 | 
			
		||||
    v_new = torch.empty_like(u)
 | 
			
		||||
    chunk_delta_rule_fwd_kernel_h[grid](
 | 
			
		||||
        k, u, w, v_new, h, initial_state, final_state,
 | 
			
		||||
        k.stride(1), k.stride(2), k.stride(3),
 | 
			
		||||
        u.stride(1), u.stride(2), u.stride(3),
 | 
			
		||||
        h.stride(1), h.stride(2),
 | 
			
		||||
        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
 | 
			
		||||
        USE_INITIAL_STATE=initial_state is not None,
 | 
			
		||||
        STORE_FINAL_STATE=final_state is not None,
 | 
			
		||||
        )
 | 
			
		||||
    return h, v_new
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
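# Backward counterpart of chunk_fwd_h_fn (a hedged reading of the call sites): it
# materializes the chunkwise state gradients dh of shape [B, H, NT*K, V] and returns
# dv2, the value gradients propagated back through the chunk recurrence.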
def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
 | 
			
		||||
    B, H, T, K, V = *q.shape, do.shape[-1]
 | 
			
		||||
 | 
			
		||||
    BK = triton.next_power_of_2(K)
 | 
			
		||||
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
 | 
			
		||||
    BV = 16 if BK > 128 else 32        
 | 
			
		||||
    BV = 64 if BK <= 64 else BV
 | 
			
		||||
    BC = 16 if BK > 128 else 32 
 | 
			
		||||
    BC = 64 if BK <= 64 else BC
 | 
			
		||||
    BC = min(BT, BC)
 | 
			
		||||
    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
 | 
			
		||||
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
 | 
			
		||||
 | 
			
		||||
    dh = q.new_empty(B, H, NT * K, V)
 | 
			
		||||
    # dv_new = torch.empty_like(do)
 | 
			
		||||
    grid = (NK, NV, B * H)
 | 
			
		||||
    dv2 = torch.empty_like(dv)
 | 
			
		||||
    chunk_delta_rule_bwd_kernel_dhu[grid](
 | 
			
		||||
        q, k, w, do, dh, dv, dv2,
 | 
			
		||||
        q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
        do.stride(1), do.stride(2), do.stride(3),
 | 
			
		||||
        dh.stride(1), dh.stride(2),
 | 
			
		||||
        K**-0.5,
 | 
			
		||||
        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
 | 
			
		||||
    )
 | 
			
		||||
    return dh, dv2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
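# Computes the chunkwise output (hedged summary): an intra-chunk causal attention term
# over v_new plus an inter-chunk term q @ h, analogous to `o_inter + attn @ v_new` in
# delta_rule_chunkwise (naive.py).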
def chunk_fwd_o_fn(q, k, v_new, h, BT):
 | 
			
		||||
    B, H, T, K, V = *q.shape, v_new.shape[-1]
 | 
			
		||||
 | 
			
		||||
    o = torch.empty_like(v_new)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
 | 
			
		||||
    NV = triton.cdiv(V, BV)
 | 
			
		||||
    NT = triton.cdiv(T, BT)
 | 
			
		||||
    grid = (NV, NT, B * H)
 | 
			
		||||
    chunk_linear_attn_fwd_kernel_o[grid](
 | 
			
		||||
            q, k, v_new, h, o,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            v_new.stride(1), v_new.stride(2), v_new.stride(3),
 | 
			
		||||
            h.stride(1), h.stride(2),
 | 
			
		||||
            scale=K**-0.5,
 | 
			
		||||
            H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
 | 
			
		||||
    )
 | 
			
		||||
    return o
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
 | 
			
		||||
    B, H, T, K, V = *q.shape, v_new.shape[-1]
 | 
			
		||||
 | 
			
		||||
    BK = min(triton.next_power_of_2(K), 64)
 | 
			
		||||
    BV = min(triton.next_power_of_2(V), 64)
 | 
			
		||||
    NV = triton.cdiv(V, BV)
 | 
			
		||||
    NT = triton.cdiv(T, BT)
 | 
			
		||||
    grid = (NV, NT, B * H)
 | 
			
		||||
    dq = torch.empty_like(q)
 | 
			
		||||
    dk = torch.empty_like(k) 
 | 
			
		||||
    dw = torch.empty_like(w) 
 | 
			
		||||
    chunk_delta_rule_bwd_kernel_dqkw[grid](
 | 
			
		||||
        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
 | 
			
		||||
        q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
        v_new.stride(1), v_new.stride(2), v_new.stride(3),
 | 
			
		||||
        dh.stride(1), dh.stride(2),
 | 
			
		||||
        scale = K ** -0.5,
 | 
			
		||||
        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
 | 
			
		||||
    )
 | 
			
		||||
    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChunkDeltaRuleFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @custom_fwd
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):        
 | 
			
		||||
        # obtain the WY representation; u is the transformed value (the new v).
        w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
        # forward pass: materialize the chunkwise hidden states
 | 
			
		||||
        final_state = None
 | 
			
		||||
        if output_final_state:
            B, H, T, K, V = *q.shape, v.shape[-1]
            final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
 | 
			
		||||
        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)        
 | 
			
		||||
        ## obtain output 
 | 
			
		||||
        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
 | 
			
		||||
        # save memory
 | 
			
		||||
        if checkpoint_level == 1:
 | 
			
		||||
            h, v_new = None, None
 | 
			
		||||
        ctx.save_for_backward(q, k, v, beta, A, h, v_new, initial_state)
 | 
			
		||||
        ctx.BT = BT
 | 
			
		||||
        return o.to(q.dtype), final_state
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @custom_bwd
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def backward(ctx, do, d_ht=None):
 | 
			
		||||
        q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
 | 
			
		||||
        scale = q.shape[-1] ** -0.5
 | 
			
		||||
        BT = ctx.BT
 | 
			
		||||
        w, u = fwd_recompute_w_u(k, v, beta, A, BT)
 | 
			
		||||
        # checkpoint_level=1: h and v_new were not saved, so recompute them.
 | 
			
		||||
        if h is None:
 | 
			
		||||
            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
 | 
			
		||||
        dv = fwd_prepare_dv(q, k, do, BT)
 | 
			
		||||
        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)
 | 
			
		||||
        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
 | 
			
		||||
        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, dv, BT)
 | 
			
		||||
        dk.add_(dk2)
 | 
			
		||||
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
 | 
			
		||||
 | 
			
		||||
def chunk_delta_rule(
 | 
			
		||||
    q: torch.Tensor,
 | 
			
		||||
    k: torch.Tensor,
 | 
			
		||||
    v: torch.Tensor,
 | 
			
		||||
    beta: torch.Tensor,
 | 
			
		||||
    BT: int,
 | 
			
		||||
    initial_state: torch.Tensor = None,
 | 
			
		||||
    output_final_state: bool = False
 | 
			
		||||
):
 | 
			
		||||
    assert q.dtype == k.dtype == v.dtype
 | 
			
		||||
    if initial_state is not None:
 | 
			
		||||
        initial_state = initial_state.detach()
 | 
			
		||||
    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT,  initial_state, output_final_state)
 | 
			
		||||
    return o, final_state
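# A minimal smoke-test sketch, assuming a CUDA device and illustrative shapes/chunk size
# (it mirrors the __main__ check in chunk_fuse.py; none of the constants below are
# prescribed by the original code):
if __name__ == '__main__':
    import torch.nn.functional as F
    B, H, T, DK, DV, BT = 2, 4, 256, 64, 64, 64
    q = torch.randn(B, H, T, DK, dtype=torch.float32, device='cuda')
    k = F.normalize(torch.randn(B, H, T, DK, dtype=torch.float32, device='cuda'), p=2, dim=-1)
    v = torch.randn(B, H, T, DV, dtype=torch.float32, device='cuda')
    beta = torch.rand(B, H, T, dtype=torch.float32, device='cuda').sigmoid()
    q, k, v, beta = map(lambda x: x.requires_grad_(True), (q, k, v, beta))
    o, _ = chunk_delta_rule(q, k, v, beta, BT)
    o.sum().backward()
    print(o.shape, q.grad.shape)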
419  finetune/lora/v6/fla/ops/delta_rule/chunk_fuse.py  vendored  Normal file
@@ -0,0 +1,419 @@
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
from packaging import version
 | 
			
		||||
from torch.cuda.amp import custom_bwd, custom_fwd
 | 
			
		||||
 | 
			
		||||
from fla.ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
 | 
			
		||||
from fla.utils import contiguous
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# on-the-fly computation without materializing hidden states into HBM
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8)
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK"],
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_chunk_delta_rule_fwd_kernel(
 | 
			
		||||
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
 | 
			
		||||
    q,  # query [B, H, L, D_head_K]
 | 
			
		||||
    k,  # key [B, H, L, D_head_K]
 | 
			
		||||
    v,  # value [B, H, L, D_head_V]
 | 
			
		||||
    v_new,
 | 
			
		||||
    d,  # decay [B, H, L, D_head_K]
 | 
			
		||||
    o,  # output [B, H, L, D_head_V]
 | 
			
		||||
    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
 | 
			
		||||
    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
    s_vo_h,  # stride size: L * D_head_V
 | 
			
		||||
    s_vo_t,  # stride size: D_head_V
 | 
			
		||||
    s_vo_d,  # stride size: 1
 | 
			
		||||
    B,  # batch size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
    DV: tl.constexpr,  # D_head_V
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,
 | 
			
		||||
    STORE_FINAL_STATE: tl.constexpr,
 | 
			
		||||
    CHECK: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    # indices
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    o_i = tl.arange(0, BT)
 | 
			
		||||
 | 
			
		||||
    # [BT, BT]
 | 
			
		||||
    m_s = o_i[:, None] >= o_i[None, :]
 | 
			
		||||
    # [BK, BV]
 | 
			
		||||
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    # make block pointers
 | 
			
		||||
    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
 | 
			
		||||
    p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
    p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
    p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
    for i in range(0, tl.cdiv(T, BT)):
 | 
			
		||||
        # [BK, BT]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, BK]
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
        b_d = tl.load(p_d, boundary_check=(0, 1))
 | 
			
		||||
        b_q = (b_q * scale).to(b_k.dtype)
 | 
			
		||||
 | 
			
		||||
        # [BT, BT]
 | 
			
		||||
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
 | 
			
		||||
        b_s = tl.where(m_s, b_s, 0)
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)
 | 
			
		||||
        b_v = b_v - b_v_prime
 | 
			
		||||
        tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
        b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False)
 | 
			
		||||
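        # NOTE: the two branches below are identical; peeling the first iteration off
        # via CHECK appears to be a workaround for a Triton < 2.2.0 compiler issue
        # (see the commented warning in fused_chunk_delta_rule_fwd below).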
        if CHECK and i == 0:
 | 
			
		||||
            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
 | 
			
		||||
            b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
        else:
 | 
			
		||||
            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
 | 
			
		||||
            b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
        p_q = tl.advance(p_q, (BT, 0))
 | 
			
		||||
        p_k = tl.advance(p_k, (0, BT))
 | 
			
		||||
        p_v = tl.advance(p_v, (BT, 0))
 | 
			
		||||
        p_v_new = tl.advance(p_v_new, (BT, 0))
 | 
			
		||||
        p_o = tl.advance(p_o, (BT, 0))
 | 
			
		||||
        p_d = tl.advance(p_d, (BT, 0))
 | 
			
		||||
 | 
			
		||||
    if STORE_FINAL_STATE:
 | 
			
		||||
        p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK", "BV"],
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_chunk_delta_rule_bwd_kernel(
 | 
			
		||||
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
 | 
			
		||||
    # NV: number of split in the V dimension. NK: number of split in the K dimension
 | 
			
		||||
    q,  # query [B, H, L, D_head_K]
 | 
			
		||||
    k,  # key [B, H, L, D_head_K]
 | 
			
		||||
    v,  # value [B, H, L, D_head_V]
 | 
			
		||||
    d,  # decay [B, H, L, D_head_K]
 | 
			
		||||
    do,  # gradient of output [B, H, L, D_head_V]
 | 
			
		||||
    dq,  # gradient of query [NV, B, H, L, D_head_K]
 | 
			
		||||
    dk,  # gradient of key [NV, B, H, L, D_head_K]
 | 
			
		||||
    dv,  # gradient of value [NK, B, H, L, D_head_V]
 | 
			
		||||
    dd,  # gradient of decay [NV, B, H, L, D_head_K]
 | 
			
		||||
    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
    s_vo_h,  # stride size: L * D_head_V
 | 
			
		||||
    s_vo_t,  # stride size: D_head_V
 | 
			
		||||
    s_vo_d,  # stride size: 1
 | 
			
		||||
    B,  # batch_size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
    DV: tl.constexpr,  # D_head_V
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,
 | 
			
		||||
    CHECK: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    o_i = tl.arange(0, BT)
 | 
			
		||||
 | 
			
		||||
    # first reverse
 | 
			
		||||
    # [BK, BV]
 | 
			
		||||
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
    m_s = o_i[:, None] <= o_i[None, :]
 | 
			
		||||
    for i in range(1, tl.cdiv(T, BT) + 1):
 | 
			
		||||
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
 | 
			
		||||
        p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
 | 
			
		||||
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
 | 
			
		||||
        # [DK, BT]
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
        b_q = (b_q * scale).to(b_q.dtype)
 | 
			
		||||
        # [BT, DK]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, DV]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
        # [BT, BT]
 | 
			
		||||
        b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
 | 
			
		||||
        b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
 | 
			
		||||
        # [BT, BT]
 | 
			
		||||
        b_s = tl.dot(b_k, b_q, allow_tf32=False)
 | 
			
		||||
        b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
 | 
			
		||||
        # [BT, DK]
 | 
			
		||||
        b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
 | 
			
		||||
        # [BT, DV]
 | 
			
		||||
        b_dv = tl.dot(b_s, b_do, allow_tf32=False)
 | 
			
		||||
        b_d = tl.load(p_d, boundary_check=(0, 1))
 | 
			
		||||
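        # NOTE: as in the forward kernel, the two branches below are identical; the
        # CHECK split appears to be a workaround for a Triton < 2.2.0 compiler issue.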
        if CHECK and i == 1:
 | 
			
		||||
            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype),  allow_tf32=False)
 | 
			
		||||
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
            b_dh += tl.dot(b_q, b_do, allow_tf32=False)
 | 
			
		||||
            b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
 | 
			
		||||
        else:
 | 
			
		||||
            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype),  allow_tf32=False)
 | 
			
		||||
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
            b_dh += tl.dot(b_q, b_do, allow_tf32=False)
 | 
			
		||||
            b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
        tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
    # sync threads
 | 
			
		||||
    b_h = None
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
    m_s = o_i[:, None] >= o_i[None, :]
 | 
			
		||||
    # [BV, BK]
 | 
			
		||||
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
 | 
			
		||||
        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
    NT = tl.cdiv(T, BT)
 | 
			
		||||
    for i in range(0, NT):
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
 | 
			
		||||
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
 | 
			
		||||
 | 
			
		||||
        # [BT, DK]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [DV, BT]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, DV]
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
        # [BT, BT]
 | 
			
		||||
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
 | 
			
		||||
        b_ds = tl.where(m_s, b_ds, 0)
 | 
			
		||||
        # [BT, DK]
 | 
			
		||||
        b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
 | 
			
		||||
        # [DV, DK]
 | 
			
		||||
        if CHECK and i == 0:
 | 
			
		||||
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
 | 
			
		||||
            b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
 | 
			
		||||
        else:
 | 
			
		||||
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
 | 
			
		||||
            b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
 | 
			
		||||
        b_dq *= scale
 | 
			
		||||
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
        if i < (NT - 1):
 | 
			
		||||
            p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
            b_dv = tl.load(p_dv, boundary_check=(0, 1))
 | 
			
		||||
            b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
 | 
			
		||||
            p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
 | 
			
		||||
                                     ((i+1) * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
            tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):
 | 
			
		||||
    batch_size, n_heads, seq_len, d_head_qk = q.shape
 | 
			
		||||
    d_head_v = v.shape[-1]
 | 
			
		||||
    scale = d_head_qk ** -0.5
 | 
			
		||||
 | 
			
		||||
    # ctx.BT = BT
 | 
			
		||||
    BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
 | 
			
		||||
    NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
 | 
			
		||||
    assert NK == 1, 'NK should be 1'
 | 
			
		||||
    o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)
 | 
			
		||||
    if output_final_state:
 | 
			
		||||
        final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
 | 
			
		||||
    else:
 | 
			
		||||
        final_state = None
 | 
			
		||||
    CHECK = True
 | 
			
		||||
    # if version.parse(triton.__version__) < version.parse('2.2.0'):
 | 
			
		||||
    #     import warnings
 | 
			
		||||
    #     warnings.warn(
 | 
			
		||||
    #         "Triton<2.2.0 detected for running this kernel, "
 | 
			
		||||
    #         "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
 | 
			
		||||
    #         "that lead to significant precision loss. "
 | 
			
		||||
    #         "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
 | 
			
		||||
    #         "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
 | 
			
		||||
    #     )
 | 
			
		||||
    #     CHECK = True
 | 
			
		||||
    grid = (NV, NK, batch_size * n_heads)
 | 
			
		||||
    v_new = torch.empty_like(v)
 | 
			
		||||
    fused_chunk_delta_rule_fwd_kernel[grid](
 | 
			
		||||
        q, k, v, v_new, d, o, initial_state, final_state,
 | 
			
		||||
        q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
        v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
        batch_size, n_heads, seq_len, scale,
 | 
			
		||||
        BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
 | 
			
		||||
        USE_INITIAL_STATE=initial_state is not None,
 | 
			
		||||
        STORE_FINAL_STATE=output_final_state,
 | 
			
		||||
        CHECK=CHECK,
 | 
			
		||||
    )
 | 
			
		||||
    return o, v_new, CHECK, final_state
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):
 | 
			
		||||
    batch_size, n_heads,  seq_len, d_head_qk = q.shape
 | 
			
		||||
    d_head_v = v.shape[-1]
 | 
			
		||||
    scale = d_head_qk ** -0.5
 | 
			
		||||
    BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
 | 
			
		||||
    NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
 | 
			
		||||
    assert NK == 1
 | 
			
		||||
    dq = q.new_empty(NV, batch_size, n_heads,  seq_len, d_head_qk)
 | 
			
		||||
    dk = q.new_empty(NV, batch_size, n_heads,  seq_len, d_head_qk)
 | 
			
		||||
    dd = q.new_empty(NV, batch_size, n_heads,  seq_len, d_head_qk)
 | 
			
		||||
    dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
 | 
			
		||||
    grid = (NV, NK, batch_size * n_heads)
 | 
			
		||||
    fused_chunk_delta_rule_bwd_kernel[grid](
 | 
			
		||||
        q, k, v, d, do, dq, dk, dv, dd, initial_state,
 | 
			
		||||
        q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
        v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
        batch_size, n_heads, seq_len, scale,
 | 
			
		||||
        BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
 | 
			
		||||
        USE_INITIAL_STATE=initial_state is not None,
 | 
			
		||||
        CHECK=CHECK,
 | 
			
		||||
        # num_warps=num_warps,
 | 
			
		||||
        # num_stages=num_stages
 | 
			
		||||
    )
 | 
			
		||||
    dq = dq.sum(0)
 | 
			
		||||
    dk = dk.sum(0)
 | 
			
		||||
    dv = dv.sum(0)
 | 
			
		||||
    dd = dd.sum(0)
 | 
			
		||||
    dd[:, :, 0:BT] = 0
 | 
			
		||||
    return dq, dk, dv, dd
 | 
			
		||||
 | 
			
		||||
class FusedChunkDeltaRuleFunction(torch.autograd.Function):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_fwd
 | 
			
		||||
    def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0):
 | 
			
		||||
        # lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory.
 | 
			
		||||
        assert checkpoint_level in [0, 1]
 | 
			
		||||
        k_origin = k
 | 
			
		||||
        # k = _l2_norm_fwd(k_origin)
 | 
			
		||||
 | 
			
		||||
        d, v_new = fwd_prepare_wy_repr(k, v, beta, BT)
 | 
			
		||||
        o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state)
 | 
			
		||||
        if checkpoint_level == 1:
 | 
			
		||||
            d, v_new = None, None
 | 
			
		||||
        ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state)
 | 
			
		||||
        ctx.CHECK = CHECK
 | 
			
		||||
        ctx.chunk_size = BT
 | 
			
		||||
        return o.to(q.dtype), final_state
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @custom_bwd
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def backward(ctx, do, d_final_state=None):
 | 
			
		||||
        q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors
 | 
			
		||||
        chunk_size = ctx.chunk_size
 | 
			
		||||
        k = k_origin
 | 
			
		||||
        # k = _l2_norm_fwd(k_origin)
 | 
			
		||||
        if d is None:
 | 
			
		||||
            d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
 | 
			
		||||
        dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state)
 | 
			
		||||
        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size)
 | 
			
		||||
        dk.add_(dk2)
 | 
			
		||||
        # dk = _l2_norm_bwd(k_origin, dk)
 | 
			
		||||
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(d.dtype), None, None, None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fused_chunk_delta_rule(
 | 
			
		||||
    q: torch.Tensor,
 | 
			
		||||
    k: torch.Tensor,
 | 
			
		||||
    v: torch.Tensor,
 | 
			
		||||
    beta: torch.Tensor,
 | 
			
		||||
    BT: int,
 | 
			
		||||
    initial_state: torch.Tensor = None,
 | 
			
		||||
    output_final_state: bool = False,
 | 
			
		||||
) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    if initial_state is not None:
 | 
			
		||||
        initial_state = initial_state.detach()
 | 
			
		||||
    o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
 | 
			
		||||
    return o, final_state
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def delta_rule_recurrence(q, k, v, beta):
 | 
			
		||||
    b, h, l, d_k = q.shape
 | 
			
		||||
    d_v = v.shape[-1]
 | 
			
		||||
    o = torch.zeros_like(v)
 | 
			
		||||
    S = torch.zeros(b, h, d_k, d_v).to(v)
 | 
			
		||||
    q = q * (d_k ** -0.5)
 | 
			
		||||
    k = torch.nn.functional.normalize(k, p=2, dim=-1)
 | 
			
		||||
    for i in range(l):
 | 
			
		||||
        _k = k[:, :, i]
 | 
			
		||||
        _q = q[:, :, i]
 | 
			
		||||
        _v = v[:, :, i].clone()
 | 
			
		||||
        beta_i = beta[:, :, i]
 | 
			
		||||
        _v = _v - (S.clone() * _k[..., None]).sum(-2)
 | 
			
		||||
        _v = _v * beta_i[..., None]
 | 
			
		||||
        S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
 | 
			
		||||
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
 | 
			
		||||
    return o
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    import torch.nn.functional as F
 | 
			
		||||
    seq_len = 128
 | 
			
		||||
    b = 2
 | 
			
		||||
    h = 4
 | 
			
		||||
    q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
 | 
			
		||||
    k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
 | 
			
		||||
    v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1)
 | 
			
		||||
    beta = torch.rand(b, h, seq_len).sigmoid()
 | 
			
		||||
    q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta))
 | 
			
		||||
    do = torch.rand_like(v)
 | 
			
		||||
    o2 = delta_rule_recurrence(q, k, v.clone(), beta)
 | 
			
		||||
    o2.backward(do, retain_graph=True)
 | 
			
		||||
    q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad
 | 
			
		||||
    q.grad = k.grad = v.grad = beta.grad = None
 | 
			
		||||
    o, _ = fused_chunk_delta_rule(q, k, v, beta, 32)
 | 
			
		||||
    o.backward(do, retain_graph=True)
 | 
			
		||||
    q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad
 | 
			
		||||
    q.grad = k.grad = v.grad = beta.grad = None
 | 
			
		||||
    print((o - o2).abs().max())
 | 
			
		||||
    print((q_grad - q_grad2).abs().max())
 | 
			
		||||
    print((k_grad - k_grad2).abs().max())
 | 
			
		||||
    print((v_grad - v_grad2).abs().max())
 | 
			
		||||
    print((beta_grad - beta_grad2).abs().max())
92  finetune/lora/v6/fla/ops/delta_rule/naive.py  vendored  Normal file
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def delta_rule_recurrence(q, k, v, beta):
 | 
			
		||||
    b, h, l, d_k = q.shape
 | 
			
		||||
    d_v = v.shape[-1]
 | 
			
		||||
    o = torch.zeros_like(v)
 | 
			
		||||
    S = torch.zeros(b, h, d_k, d_v).to(v)
 | 
			
		||||
    q = q * (d_k ** -0.5)
 | 
			
		||||
    for i in range(l):
 | 
			
		||||
        _k = k[:, :, i]
 | 
			
		||||
        _q = q[:, :, i]
 | 
			
		||||
        _v = v[:, :, i].clone()
 | 
			
		||||
        beta_i = beta[:, :, i]
 | 
			
		||||
        _v = _v - (S.clone() * _k[..., None]).sum(-2)
 | 
			
		||||
        _v = _v * beta_i[..., None]
 | 
			
		||||
        S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
 | 
			
		||||
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
 | 
			
		||||
    return o
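# The loop above implements, per head (a hedged summary of the code):
#     v_t' = beta_t * (v_t - S_{t-1}^T k_t)      # delta-rule correction
#     S_t  = S_{t-1} + k_t v_t'^T                # rank-1 state update
#     o_t  = S_t^T q_t                           # readout
# with the state S kept in [d_k, d_v] layout.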
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
 | 
			
		||||
    b, h, l, d_k = q.shape
 | 
			
		||||
    d_v = v.shape[-1]
 | 
			
		||||
    q = q * (d_k ** -0.5)
 | 
			
		||||
    v = v * beta[..., None]
 | 
			
		||||
    k_beta = k * beta[..., None]
 | 
			
		||||
 | 
			
		||||
    assert l % chunk_size == 0
 | 
			
		||||
 | 
			
		||||
    # note that diagonal is masked.
 | 
			
		||||
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
 | 
			
		||||
    q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
 | 
			
		||||
    attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
 | 
			
		||||
 | 
			
		||||
    for i in range(1, chunk_size):
 | 
			
		||||
        attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
 | 
			
		||||
 | 
			
		||||
    attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
 | 
			
		||||
    # u
 | 
			
		||||
    k_cumsum = attn @ v
 | 
			
		||||
    # w
 | 
			
		||||
    k_cumdecay = attn @ k_beta
 | 
			
		||||
 | 
			
		||||
    v = k_cumsum
 | 
			
		||||
    S = k.new_zeros(b, h, d_k, d_v)
 | 
			
		||||
    o = torch.zeros_like(v)
 | 
			
		||||
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
 | 
			
		||||
    for i in range(0, l // chunk_size):
 | 
			
		||||
        q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i]
 | 
			
		||||
        attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
 | 
			
		||||
        v_prime = k_cumdecay[:, :, i] @ S
 | 
			
		||||
        v_new = v_i - v_prime
 | 
			
		||||
        o_inter = q_i @ S
 | 
			
		||||
        o[:, :, i] = o_inter + attn @ v_new
 | 
			
		||||
        # chunk state update
 | 
			
		||||
        S = S + k_i.transpose(-1, -2) @ v_new
 | 
			
		||||
 | 
			
		||||
    return rearrange(o, 'b h n c d -> b h (n c) d')
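# In the chunkwise form above, `attn` ends up as T = (I + tril(K_beta K^T, -1))^{-1}
# after the in-place forward substitution, so k_cumdecay = T @ K_beta and the new v is
# T @ V_beta; these are the same w/u factors the fused WY kernels in utils.py build
# (a hedged reading of the code, cf. the WY-representation reference cited there).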
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    B = 2
 | 
			
		||||
    H = 4
 | 
			
		||||
    L = 256
 | 
			
		||||
    DK = 128
 | 
			
		||||
    DV = 128
 | 
			
		||||
    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
 | 
			
		||||
    k = (torch.randn(B, H, L, DK)).cuda()
 | 
			
		||||
    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
 | 
			
		||||
    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
 | 
			
		||||
    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
 | 
			
		||||
 | 
			
		||||
    o = delta_rule_recurrence(q, k, v, beta)
 | 
			
		||||
    do = torch.randn(B, H, L, DV).cuda()
 | 
			
		||||
    o.backward(do, retain_graph=True)
 | 
			
		||||
    q_grad, q.grad = q.grad, None
 | 
			
		||||
    k_grad, k.grad = k.grad, None
 | 
			
		||||
    v_grad, v.grad = v.grad, None
 | 
			
		||||
    beta_grad, beta.grad = beta.grad, None
 | 
			
		||||
 | 
			
		||||
    o2 = delta_rule_chunkwise(q, k, v, beta)
 | 
			
		||||
    o2.backward(do)
 | 
			
		||||
    assert torch.allclose(o, o2, atol=1e-4), breakpoint()
 | 
			
		||||
    assert torch.allclose(q.grad, q_grad, atol=1e-4), breakpoint()
 | 
			
		||||
    assert torch.allclose(k.grad, k_grad, atol=1e-4), breakpoint()
 | 
			
		||||
    assert torch.allclose(v.grad, v_grad, atol=1e-4), breakpoint()
 | 
			
		||||
    assert torch.allclose(beta.grad, beta_grad, atol=1e-4), breakpoint()
 | 
			
		||||
    print("All passed!")
312  finetune/lora/v6/fla/ops/delta_rule/recurrent_fuse.py  vendored  Normal file
@@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-
 | 
			
		||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
 | 
			
		||||
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
 | 
			
		||||
from fla.utils import contiguous
 | 
			
		||||
 | 
			
		||||
# on-the-fly computation without materializing hidden states into HBM
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_recurrent_fwd_kernel(
 | 
			
		||||
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
 | 
			
		||||
    q,  # query [B, H, L, D_head_K]
 | 
			
		||||
    k,  # key [B, H, L, D_head_K]
 | 
			
		||||
    v,  # value [B, H, L, D_head_V].
 | 
			
		||||
    beta,  # beta [B, H, L]
 | 
			
		||||
    o,  # output [B, H, L, D_head_V]
 | 
			
		||||
    initial_state,
 | 
			
		||||
    final_state,  # final hidden state [B, H, D_head_K, D_head_V]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    s_vo_h,  # stride size: L * D_head_V
 | 
			
		||||
    s_vo_t,  # stride size: D_head_V
 | 
			
		||||
    s_vo_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    B,  # batch size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
    DV: tl.constexpr,  # D_head_V
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
 | 
			
		||||
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
    # indices
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
 | 
			
		||||
    p_beta = beta + i_bh * T
 | 
			
		||||
    p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
 | 
			
		||||
 | 
			
		||||
    mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
 | 
			
		||||
    mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
 | 
			
		||||
    mask_kv = mask_bk[None, :] & mask_bv[:, None]
 | 
			
		||||
 | 
			
		||||
    h = tl.zeros([BV, BK], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_init_s = initial_state + i_bh * DK * DV + \
 | 
			
		||||
            (i_k * BK + tl.arange(0, BK)[None, :]) * \
 | 
			
		||||
            DV + (i_v * BV + tl.arange(0, BV)[:, None])
 | 
			
		||||
        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
    for _ in range(0, T):
 | 
			
		||||
        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
 | 
			
		||||
        _v_minus = tl.sum(h * _k[None, :], axis=1)
 | 
			
		||||
        _v -= _v_minus
 | 
			
		||||
        _beta = tl.load(p_beta).to(tl.float32)
 | 
			
		||||
        # in-place overwrite
 | 
			
		||||
        tl.store(p_v, _v.to(p_v.dtype.element_ty), mask=mask_bv)
 | 
			
		||||
        _v *= _beta
 | 
			
		||||
        h += _k[None, :] * _v[:, None]
 | 
			
		||||
        _o = h * _q[None, :]
 | 
			
		||||
        _o = tl.sum(_o, axis=1)
 | 
			
		||||
        tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
 | 
			
		||||
 | 
			
		||||
        p_q += DK
 | 
			
		||||
        p_k += DK
 | 
			
		||||
        p_o += DV
 | 
			
		||||
        p_v += DV
 | 
			
		||||
        p_beta += 1
 | 
			
		||||
 | 
			
		||||
    if STORE_FINAL_STATE:
 | 
			
		||||
        p_final_s = final_state + i_bh * DK * DV + \
 | 
			
		||||
            (i_k * BK + tl.arange(0, BK)[None, :]) * \
 | 
			
		||||
            DV + (i_v * BV + tl.arange(0, BV)[:, None])
 | 
			
		||||
        tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
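# NOTE (hedged): the forward kernel overwrites `v` in place with the corrected values
# v_t - S^T k_t (before the beta scaling); the backward kernel below reloads these
# corrected values from `v`, so callers should treat `v` as clobbered after the forward.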
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_recurrent_bwd_kernel(
 | 
			
		||||
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
 | 
			
		||||
    # NV: number of split in the V dimension. NK: number of split in the K dimension
 | 
			
		||||
    q,  # query [B, H, L, D_head_K]
 | 
			
		||||
    k,  # key [B, H, L, D_head_K]
 | 
			
		||||
    v,  # value [B, H, L, D_head_V]
 | 
			
		||||
    beta,  # beta [B, H, L]
 | 
			
		||||
 | 
			
		||||
    do,  # gradient of output [B, H, L, D_head_V]
 | 
			
		||||
    dq,  # gradient of query [NV, B, H, L, D_head_K]
 | 
			
		||||
    dk,  # gradient of key [NV, B, H, L, D_head_K]
 | 
			
		||||
    dv,  # gradient of value [NK, B, H, L, D_head_V]
 | 
			
		||||
    dbeta,  # gradient of beta [B, H, L]
 | 
			
		||||
 | 
			
		||||
    # initial hidden state initialization [B, H, D_head_K, D_head_V]
 | 
			
		||||
    initial_state,
 | 
			
		||||
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    s_vo_h,  # stride size: L * D_head_V
 | 
			
		||||
    s_vo_t,  # stride size: D_head_V
 | 
			
		||||
    s_vo_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    B,  # batch_size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
    DV: tl.constexpr,  # D_head_V
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
 | 
			
		||||
):
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    mask_bk = i_k * BK + tl.arange(0, BK) < DK
 | 
			
		||||
    mask_bv = i_v * BV + tl.arange(0, BV) < DV
 | 
			
		||||
 | 
			
		||||
    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
 | 
			
		||||
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
 | 
			
		||||
    p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
 | 
			
		||||
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
 | 
			
		||||
    p_beta = beta + i_bh * T + T - 1
 | 
			
		||||
    p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1
 | 
			
		||||
 | 
			
		||||
    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
 | 
			
		||||
        BK + tl.arange(0, BK) + (T - 1) * DK
 | 
			
		||||
    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
 | 
			
		||||
        BV + tl.arange(0, BV) + (T - 1) * DV
 | 
			
		||||
    d_h = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    for _ in range(T):
 | 
			
		||||
        _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
 | 
			
		||||
        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        _beta = tl.load(p_beta).to(tl.float32)
 | 
			
		||||
        d_h += _q[:, None] * _do[None, :]
 | 
			
		||||
        d_k = tl.sum(d_h * _v[None, :] * _beta, axis=1)
 | 
			
		||||
        d_v = tl.sum(d_h * _k[:, None], axis=0)
 | 
			
		||||
 | 
			
		||||
        d_beta = tl.sum(d_v * _v)
 | 
			
		||||
        d_v = d_v * _beta
 | 
			
		||||
 | 
			
		||||
        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
 | 
			
		||||
        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
 | 
			
		||||
        tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))
 | 
			
		||||
 | 
			
		||||
        d_h -= _k[:, None] * d_v[None, :]
 | 
			
		||||
 | 
			
		||||
        p_do -= DV
 | 
			
		||||
        p_q -= DK
 | 
			
		||||
        p_k -= DK
 | 
			
		||||
        p_v -= DV
 | 
			
		||||
        p_dk -= DK
 | 
			
		||||
        p_dv -= DV
 | 
			
		||||
        p_dbeta -= 1
 | 
			
		||||
        p_beta -= 1
 | 
			
		||||
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
 | 
			
		||||
    h = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
 | 
			
		||||
    p_beta = beta + i_bh * T
 | 
			
		||||
    p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
 | 
			
		||||
    p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + DV
 | 
			
		||||
    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + DK
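    # p_dk / p_dv start one step ahead (note the trailing + DK / + DV): in the second
    # pass, iteration t corrects dk at step t+1 using dv at step t+1 and the state h
    # accumulated through step t (a hedged reading of the `if i < T - 1` block below).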
 | 
			
		||||
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        mask_kv = mask_bk[:, None] & mask_bv[None, :]
 | 
			
		||||
        p_init_s = initial_state + i_bh * DK * DV + \
 | 
			
		||||
            (i_k * BK + tl.arange(0, BK)[:, None]) * \
 | 
			
		||||
            DV + (i_v * BV + tl.arange(0, BV)[None, :])
 | 
			
		||||
        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
    for i in range(0, T):
 | 
			
		||||
        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        _beta = tl.load(p_beta).to(tl.float32)
 | 
			
		||||
        _v *= _beta
 | 
			
		||||
 | 
			
		||||
        h += _k[:, None] * _v[None, :]
 | 
			
		||||
        _d_q = h * _do[None, :]
 | 
			
		||||
        d_q = tl.sum(_d_q, axis=1) * scale
 | 
			
		||||
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
 | 
			
		||||
 | 
			
		||||
        if i < T - 1:
 | 
			
		||||
            d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
            d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
            d_k -= tl.sum(d_v[None, :] * h, axis=1)
 | 
			
		||||
            tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
 | 
			
		||||
 | 
			
		||||
        p_k += DK
 | 
			
		||||
        p_do += DV
 | 
			
		||||
        p_v += DV
 | 
			
		||||
        p_dk += DK
 | 
			
		||||
        p_dv += DV
 | 
			
		||||
        p_dq += DK
 | 
			
		||||
        p_beta += 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedRecurrentFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def forward(ctx, q, k, v, beta, initial_state=None, output_final_state=False):
 | 
			
		||||
        batch_size, n_heads, seq_len, d_head_qk = q.shape
 | 
			
		||||
        d_head_v = v.shape[-1]
 | 
			
		||||
 | 
			
		||||
        scale = d_head_qk ** -0.5
 | 
			
		||||
        BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 8)
 | 
			
		||||
        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
 | 
			
		||||
        num_stages = 1
 | 
			
		||||
        num_warps = 1
 | 
			
		||||
        assert NK == 1, "NK > 1 is not supported yet"
 | 
			
		||||
        o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
 | 
			
		||||
 | 
			
		||||
        if output_final_state:
 | 
			
		||||
            final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
 | 
			
		||||
        else:
 | 
			
		||||
            final_state = None
 | 
			
		||||
 | 
			
		||||
        grid = (NV, NK, batch_size * n_heads)
 | 
			
		||||
        fused_recurrent_fwd_kernel[grid](
 | 
			
		||||
            q, k, v, beta, o, initial_state, final_state,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages,
 | 
			
		||||
            USE_INITIAL_STATE=initial_state is not None,
 | 
			
		||||
            STORE_FINAL_STATE=final_state is not None
 | 
			
		||||
        )
 | 
			
		||||
        o = o.sum(0)
 | 
			
		||||
        ctx.save_for_backward(q, k, v, beta, initial_state)
 | 
			
		||||
        return o, final_state
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def backward(ctx, do, d_final_state=None):
 | 
			
		||||
        q, k, v, beta, initial_state = ctx.saved_tensors
 | 
			
		||||
        batch_size, n_heads, seq_len, d_head_qk = q.shape
 | 
			
		||||
        d_head_v = v.shape[-1]
 | 
			
		||||
        scale = d_head_qk ** -0.5
 | 
			
		||||
        BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
 | 
			
		||||
        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
 | 
			
		||||
        assert NK == 1, "NK > 1 is not supported yet"
 | 
			
		||||
        num_stages = 1
 | 
			
		||||
        num_warps = 2
 | 
			
		||||
 | 
			
		||||
        dq = q.new_empty(NV, batch_size, n_heads,  seq_len, d_head_qk)
 | 
			
		||||
        dk = q.new_empty(NV, batch_size, n_heads,  seq_len, d_head_qk)
 | 
			
		||||
        dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
 | 
			
		||||
        grid = (NV, NK, batch_size * n_heads)
 | 
			
		||||
        dbeta = q.new_empty(NV, batch_size, n_heads, seq_len)
 | 
			
		||||
 | 
			
		||||
        fused_recurrent_bwd_kernel[grid](
 | 
			
		||||
            q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages,
 | 
			
		||||
            USE_INITIAL_STATE=initial_state is not None
 | 
			
		||||
        )
 | 
			
		||||
        dq = dq.sum(0)
 | 
			
		||||
        dk = dk.sum(0)
 | 
			
		||||
        dv = dv.sum(0)
 | 
			
		||||
        dbeta = dbeta.sum(0)
 | 
			
		||||
        return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
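# When `beta` is None the wrapper below fills it with ones, which reduces the update to
# the plain delta rule v_t' = v_t - S^T k_t with unit write strength (hedged reading).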
def fused_recurrent_linear_attn_delta_rule(
 | 
			
		||||
    q: torch.Tensor,
 | 
			
		||||
    k: torch.Tensor,
 | 
			
		||||
    v: torch.Tensor,
 | 
			
		||||
    beta: torch.Tensor = None,
 | 
			
		||||
    initial_state: torch.Tensor = None,
 | 
			
		||||
    output_final_state: bool = False,
 | 
			
		||||
    normalize: bool = False
 | 
			
		||||
) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    if initial_state is not None:
 | 
			
		||||
        initial_state = initial_state.detach()
 | 
			
		||||
    if beta is None:
 | 
			
		||||
        beta = torch.ones_like(q[..., 0])
 | 
			
		||||
    o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, initial_state, output_final_state)
 | 
			
		||||
    return o, final_state
297  finetune/lora/v6/fla/ops/delta_rule/utils.py  vendored  Normal file
@@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
from torch.cuda.amp import custom_bwd, custom_fwd
 | 
			
		||||
 | 
			
		||||
from fla.utils import contiguous
 | 
			
		||||
from fla.ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
 | 
			
		||||
# o: cumprod
 | 
			
		||||
# o2: cumprodsum
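# Concretely (hedged): `o` receives w = T @ (beta * k) and `o2` receives u = T @ (beta * v),
# where T is the inverse built by the row loop inside the kernel; this mirrors
# k_cumdecay / k_cumsum in delta_rule_chunkwise (naive.py).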
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK", "BV"],
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fwd_prepare_wy_repr_kernel(
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    beta,
 | 
			
		||||
    o,
 | 
			
		||||
    o2,
 | 
			
		||||
    T,
 | 
			
		||||
    K,
 | 
			
		||||
    V,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
 | 
			
		||||
 | 
			
		||||
    p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
 | 
			
		||||
    p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
 | 
			
		||||
    p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
 | 
			
		||||
    mask_bt = (tl.arange(0, BT) + i_t * BT) < T
 | 
			
		||||
    mask_bk = tl.arange(0, BK) < K
 | 
			
		||||
    mask_bv = tl.arange(0, BV) < V
 | 
			
		||||
    mask_bk = mask_bk[None, :] & mask_bt[:, None]
 | 
			
		||||
    mask_bv = mask_bv[None, :] & mask_bt[:, None]
 | 
			
		||||
    # [BT, BK]
 | 
			
		||||
    b_k = tl.load(p_k, mask=mask_bk, other=0)
 | 
			
		||||
    # [BT,]
 | 
			
		||||
    b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)
 | 
			
		||||
    # [BT, BV]
 | 
			
		||||
    b_v = tl.load(p_v, mask=mask_bv, other=0)
 | 
			
		||||
    b_v = (b_v * b_beta[:, None]).to(b_v.dtype)
 | 
			
		||||
    # [BT, BK]
 | 
			
		||||
    b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
 | 
			
		||||
    # [BT, BT]
 | 
			
		||||
    b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
 | 
			
		||||
    b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
 | 
			
		||||
 | 
			
		||||
    for i in range(BT):
 | 
			
		||||
        mask = tl.arange(0, BT) == i
 | 
			
		||||
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
 | 
			
		||||
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
 | 
			
		||||
        b_A = tl.where(mask[:, None], b_a, b_A)
 | 
			
		||||
    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
 | 
			
		||||
    b_A = b_A.to(b_k.dtype)
 | 
			
		||||
    b_w = tl.dot(b_A, b_kb, allow_tf32=False)
 | 
			
		||||
    b_u = tl.dot(b_A, b_v, allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
    p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:,  None]) * K + tl.arange(0, BK)[None, :]
 | 
			
		||||
    tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)
 | 
			
		||||
    p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
 | 
			
		||||
    tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK", "BV"],
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def bwd_prepare_wy_repr_kernel(
 | 
			
		||||
    k, v, beta,
 | 
			
		||||
    o, o2, do, do2,
 | 
			
		||||
    dk, dv, dbeta,
 | 
			
		||||
    NT, K, V, T,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr,
 | 
			
		||||
):
 | 
			
		||||
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
 | 
			
		||||
    p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
 | 
			
		||||
    p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
 | 
			
		||||
    p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
 | 
			
		||||
 | 
			
		||||
    p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
 | 
			
		||||
    mask_bt = (tl.arange(0, BT) + i_t * BT) < T
 | 
			
		||||
    mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]
 | 
			
		||||
    mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]
 | 
			
		||||
    b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)
 | 
			
		||||
 | 
			
		||||
    b_beta = b_beta.to(tl.float32)
 | 
			
		||||
    A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]
 | 
			
		||||
    A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)
 | 
			
		||||
    b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)
 | 
			
		||||
    b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)
 | 
			
		||||
    dA = tl.zeros([BT, BT], dtype=tl.float32)
 | 
			
		||||
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
 | 
			
		||||
    for i in range(BT-1, -1, -1):
 | 
			
		||||
        mask = tl.arange(0, BT) == i
 | 
			
		||||
        attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)
 | 
			
		||||
        do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)
 | 
			
		||||
        dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)
 | 
			
		||||
        b_do = b_do - attn[:, None] * do_[None, :]
 | 
			
		||||
        b_dv = b_dv - attn[:, None] * dv_[None, :]
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
    p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
 | 
			
		||||
    b_v = tl.load(p_v, mask=mask_bv)
 | 
			
		||||
    b_dk += b_do * b_beta[:, None]
 | 
			
		||||
    b_dbeta = tl.sum(b_do * b_k, axis=1)
 | 
			
		||||
    b_dbeta += tl.sum(b_dv * b_v, axis=1)
 | 
			
		||||
    b_v = None
 | 
			
		||||
 | 
			
		||||
    p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
 | 
			
		||||
    p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
 | 
			
		||||
    b_o = tl.load(p_o, mask=mask_bk)
 | 
			
		||||
    b_o2 = tl.load(p_o2, mask=mask_bv)
 | 
			
		||||
 | 
			
		||||
    dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)
 | 
			
		||||
    dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),
 | 
			
		||||
                 allow_tf32=False)
 | 
			
		||||
    dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)
 | 
			
		||||
    b_dv *= b_beta[:, None]
 | 
			
		||||
    p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
 | 
			
		||||
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
 | 
			
		||||
 | 
			
		||||
    b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)
 | 
			
		||||
    dA = dA * b_beta[:, None]
 | 
			
		||||
    b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)
 | 
			
		||||
    b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)
 | 
			
		||||
    p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
 | 
			
		||||
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
 | 
			
		||||
    p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)
 | 
			
		||||
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fwd_prepare_wy_repr(k, v, beta, chunk_size):
    B, H, T, K, V = *k.shape, v.shape[-1]
    v_new = torch.empty_like(v)
    o_cumdecay = torch.empty_like(k)
    BT = chunk_size
    NT = triton.cdiv(T, BT)
    BK = triton.next_power_of_2(K)
    BV = triton.next_power_of_2(V)
    fwd_prepare_wy_repr_kernel[(NT, B*H)](
        k, v, beta, o_cumdecay, v_new,
        T, K, V, BT, BK, BV
    )
    return o_cumdecay, v_new


def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):
    b, h, l, d_k = do.shape
    d_v = v.shape[-1]
    BK = triton.next_power_of_2(d_k)
    BV = triton.next_power_of_2(d_v)
    c = chunk_size
    BK = d_k
    NT = triton.cdiv(l, c)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    dbeta = torch.zeros_like(beta)
    bwd_prepare_wy_repr_kernel[(NT, b*h)](
        k, v, beta,
        o_cumdecay, v_new, do, do2,
        dk, dv, dbeta,
        NT, d_k, d_v, l, chunk_size, BK, BV
    )
    return dk, dv, dbeta


class WYRepresentationPrepration(torch.autograd.Function):

    @staticmethod
    @contiguous
    @custom_fwd
    def forward(ctx, k, v, beta, chunk_size):
        o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
        ctx.chunk_size = chunk_size
        ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new)
        return o_cumdecay, v_new

    @staticmethod
    @contiguous
    @custom_bwd
    def backward(ctx, do, do2):
        k, v, beta, o_cumdecay, v_new = ctx.saved_tensors
        dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size)
        return dk, dv, dbeta, None


prepare_wy_repr = WYRepresentationPrepration.apply


def naive(k, v, beta, chunk_size):
    l_org = k.shape[2]
    l_new = triton.next_power_of_2(l_org)
    # pad k, v, beta
    k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
    v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
    beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)

    k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
    # k = torch.nn.functional.normalize(k, dim=-1, p=2)
    beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
    k_beta = k * beta[..., None]
    v = v * beta[..., None]
    attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
    attn = attn * beta[..., None]
    x = attn @ v

    o = torch.zeros_like(k)
    o2 = torch.zeros_like(v)

    o[..., 0, :] = k_beta[..., 0, :].clone()
    o2[..., 0, :] = x[..., 0, :].clone()
    for i in range(1, chunk_size):
        o_i = (o[..., :i, :]).clone()
        o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
        o2_i = (o2[..., :i, :]).clone()
        o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
    return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
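
For intuition, the two Python loops above are forward substitutions; a minimal closed-form sketch of the same computation (a sketch only, reusing the k_beta, attn and x defined inside naive, and casting to float32 since solve_triangular does not accept bfloat16):

# attn is strictly lower-triangular, so (I + attn) is unit lower-triangular and
# the loops solve (I + attn) @ o = k_beta and (I + attn) @ o2 = x row by row.
eye = torch.eye(chunk_size, device=k.device, dtype=torch.float32)
tri = eye + attn.float()
o_closed = torch.linalg.solve_triangular(tri, k_beta.float(), upper=False)
o2_closed = torch.linalg.solve_triangular(tri, x.float(), upper=False)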
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    torch.set_default_dtype(torch.bfloat16)
 | 
			
		||||
    seq_len = 2048
 | 
			
		||||
    b = 4
 | 
			
		||||
    h = 8
 | 
			
		||||
    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2)
 | 
			
		||||
    v = torch.randn(b, h, seq_len, 256) 
 | 
			
		||||
    beta = torch.rand(b, h, seq_len).sigmoid()
 | 
			
		||||
    require_grad = True
 | 
			
		||||
    k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
 | 
			
		||||
    do = torch.rand_like(k)
 | 
			
		||||
    do2 = torch.rand_like(v)
 | 
			
		||||
 | 
			
		||||
    print("Start warmup.")
 | 
			
		||||
    o1, o2 = prepare_wy_repr(k, v, beta, 32)
 | 
			
		||||
    # (o1 * do + o2 * do2).sum().backward()
 | 
			
		||||
    o3, o4 = prepare_wy_repr2(k, v, beta, 32)
 | 
			
		||||
    # (o1 * do + o2 * do2).sum().backward()
 | 
			
		||||
    print((o1 - o3).abs().max())
 | 
			
		||||
    print((o2 - o4).abs().max())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    for i in range(30):
 | 
			
		||||
        o1, o2 = prepare_wy_repr(k, v, beta, 32)
 | 
			
		||||
        (o1 * do + o2 * do2).sum().backward()
 | 
			
		||||
        o1, o2 = prepare_wy_repr2(k, v, beta, 32)
 | 
			
		||||
        (o1 * do + o2 * do2).sum().backward()
 | 
			
		||||
 | 
			
		||||
    print("Done warmup.")
 | 
			
		||||
 | 
			
		||||
    import time 
 | 
			
		||||
    torch.cuda.synchronize()
 | 
			
		||||
    start = time.time()
 | 
			
		||||
 | 
			
		||||
    for i in range(200):
 | 
			
		||||
        o1, o2 = prepare_wy_repr(k, v, beta, 64)
 | 
			
		||||
        (o1 * do + o2 * do2).sum().backward()
 | 
			
		||||
 | 
			
		||||
    torch.cuda.synchronize()
 | 
			
		||||
    print(time.time() - start)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    torch.cuda.synchronize()
 | 
			
		||||
    start = time.time()
 | 
			
		||||
 | 
			
		||||
    for i in range(200):
 | 
			
		||||
        o1, o2 = prepare_wy_repr2(k, v, beta, 64)
 | 
			
		||||
        (o1 * do + o2 * do2).sum().backward()
 | 
			
		||||
 | 
			
		||||
    torch.cuda.synchronize()
 | 
			
		||||
    print(time.time() - start)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    
401 finetune/lora/v6/fla/ops/delta_rule/wy_fast.py vendored Normal file
@ -0,0 +1,401 @@
# -*- coding: utf-8 -*-

import torch
import triton
import triton.language as tl
from einops import rearrange
from torch.cuda.amp import custom_bwd, custom_fwd

from fla.utils import contiguous

# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
# o: cumprod
# o2: cumprodsum
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK", "BV"], 
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fwd_prepare_wy_repr_kernel(
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    beta,
 | 
			
		||||
    w,  
 | 
			
		||||
    u,
 | 
			
		||||
    A, 
 | 
			
		||||
    s_qk_h,
 | 
			
		||||
    s_qk_t,
 | 
			
		||||
    s_qk_d,
 | 
			
		||||
    s_vo_h,
 | 
			
		||||
    s_vo_t,
 | 
			
		||||
    s_vo_d,
 | 
			
		||||
    T,
 | 
			
		||||
    K,
 | 
			
		||||
    V,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
 | 
			
		||||
    
 | 
			
		||||
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
 | 
			
		||||
    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
 | 
			
		||||
    b_beta = tl.load(p_beta, boundary_check=(0,))
 | 
			
		||||
 | 
			
		||||
    for i_k in range(tl.cdiv(K, BK)):
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
 | 
			
		||||
        b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
    b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
 | 
			
		||||
 | 
			
		||||
    for i in range(1, BT):
 | 
			
		||||
        mask = tl.arange(0, BT) == i
 | 
			
		||||
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
 | 
			
		||||
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
 | 
			
		||||
        b_A = tl.where(mask[:, None], b_a, b_A)
 | 
			
		||||
 | 
			
		||||
    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
 | 
			
		||||
 | 
			
		||||
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
 | 
			
		||||
    tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
    b_A = b_A.to(k.dtype.element_ty)
 | 
			
		||||
 | 
			
		||||
    for i_v in range(tl.cdiv(V, BV)):
 | 
			
		||||
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
 | 
			
		||||
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
 | 
			
		||||
        p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
    for i_k in range(tl.cdiv(K, BK)):
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
 | 
			
		||||
        b_w = tl.dot(b_A, b_kb, allow_tf32=False)
 | 
			
		||||
        p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK", "BV"], 
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fwd_recompute_w_u_kernel(
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    beta,
 | 
			
		||||
    w,  
 | 
			
		||||
    u,
 | 
			
		||||
    A, 
 | 
			
		||||
    s_qk_h,
 | 
			
		||||
    s_qk_t,
 | 
			
		||||
    s_qk_d,
 | 
			
		||||
    s_vo_h,
 | 
			
		||||
    s_vo_t,
 | 
			
		||||
    s_vo_d,
 | 
			
		||||
    T,
 | 
			
		||||
    K,
 | 
			
		||||
    V,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
 | 
			
		||||
    
 | 
			
		||||
    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
 | 
			
		||||
    b_beta = tl.load(p_beta, boundary_check=(0,))
 | 
			
		||||
 | 
			
		||||
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
 | 
			
		||||
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
 | 
			
		||||
 | 
			
		||||
    for i_v in range(tl.cdiv(V, BV)):
 | 
			
		||||
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
 | 
			
		||||
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
 | 
			
		||||
        p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
    for i_k in range(tl.cdiv(K, BK)):
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
 | 
			
		||||
        b_w = tl.dot(b_A, b_kb, allow_tf32=False)
 | 
			
		||||
        p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({}, num_warps=1),
 | 
			
		||||
        triton.Config({}, num_warps=2),
 | 
			
		||||
        triton.Config({}, num_warps=4),
 | 
			
		||||
        triton.Config({}, num_warps=8),
 | 
			
		||||
        triton.Config({}, num_warps=16),
 | 
			
		||||
        triton.Config({}, num_warps=32),
 | 
			
		||||
    ],
 | 
			
		||||
    key=["BT", "BK", "BV"],
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def bwd_prepare_wy_repr_kernel(
 | 
			
		||||
    k, v, beta, A,  
 | 
			
		||||
    dw, du,
 | 
			
		||||
    dk, dv, dbeta,
 | 
			
		||||
    s_qk_h,
 | 
			
		||||
    s_qk_t,
 | 
			
		||||
    s_qk_d,
 | 
			
		||||
    s_vo_h,
 | 
			
		||||
    s_vo_t,
 | 
			
		||||
    s_vo_d,
 | 
			
		||||
    T,
 | 
			
		||||
    K,
 | 
			
		||||
    V,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
 | 
			
		||||
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
 | 
			
		||||
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
 | 
			
		||||
 | 
			
		||||
    b_dbeta = tl.zeros([BT], dtype=tl.float32)
 | 
			
		||||
    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
 | 
			
		||||
    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
 | 
			
		||||
    b_beta = tl.load(p_beta, boundary_check=(0,))
 | 
			
		||||
 | 
			
		||||
    for i_v in range(tl.cdiv(V, BV)):
 | 
			
		||||
        p_v =  tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
 | 
			
		||||
        b_du = tl.load(p_du, boundary_check=(0, 1))
 | 
			
		||||
        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
 | 
			
		||||
        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)
 | 
			
		||||
        b_dv = b_dv_beta * b_beta[:, None]
 | 
			
		||||
        b_dbeta += tl.sum(b_dv_beta * b_v, 1)
 | 
			
		||||
        # store
 | 
			
		||||
        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
    tl.debug_barrier()    
 | 
			
		||||
    b_A2 = tl.zeros([BT, BT], dtype=tl.float32)
 | 
			
		||||
    for i_k in range(tl.cdiv(K, BK)):
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))        
 | 
			
		||||
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
 | 
			
		||||
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
 | 
			
		||||
        b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)       
 | 
			
		||||
        b_A2 += tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)
 | 
			
		||||
        b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
 | 
			
		||||
        b_dk = b_dk_beta * b_beta[:, None]
 | 
			
		||||
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
 | 
			
		||||
        # store        
 | 
			
		||||
        p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
    b_A -= (tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :])
 | 
			
		||||
    b_A2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_A2, 0)
 | 
			
		||||
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
 | 
			
		||||
    for i in range(BT-1, 0, -1):
 | 
			
		||||
        mask = tl.arange(0, BT) == i
 | 
			
		||||
        b_da = tl.sum(tl.where(mask[:, None], b_dA, 0), 0) 
 | 
			
		||||
        b_a =  tl.sum(tl.where(mask[:, None], b_A2, 0), 0) 
 | 
			
		||||
        b_da2 = b_da + tl.sum(b_da[None, :] * b_A, 1)     
 | 
			
		||||
        b_dA = tl.where(mask[:, None], b_da2, b_dA)
 | 
			
		||||
        b_dA += b_da[None, :] * b_a[:, None]
 | 
			
		||||
 | 
			
		||||
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
 | 
			
		||||
    for i_k in range(tl.cdiv(K, BK)):
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))        
 | 
			
		||||
        b_dk = tl.load(p_dk, boundary_check=(0, 1))
 | 
			
		||||
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
 | 
			
		||||
 | 
			
		||||
        b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
 | 
			
		||||
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
 | 
			
		||||
        b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False) 
 | 
			
		||||
        b_dk += b_dk_beta * b_beta[:, None]        
 | 
			
		||||
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
    
 | 
			
		||||
    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
 | 
			
		||||
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty),boundary_check=(0,))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fwd_prepare_wy_repr(k, v, beta, BT):
 | 
			
		||||
    B, H, T, K, V = *k.shape, v.shape[-1]
 | 
			
		||||
    u = torch.empty_like(v)
 | 
			
		||||
    w = torch.empty_like(k)
 | 
			
		||||
    NT = triton.cdiv(T, BT)
 | 
			
		||||
    BK = min(triton.next_power_of_2(K), 64)
 | 
			
		||||
    BV = min(triton.next_power_of_2(V), 64)
 | 
			
		||||
    A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)
 | 
			
		||||
    fwd_prepare_wy_repr_kernel[(NT, B*H)](
 | 
			
		||||
        k, v, beta, w, u, A,
 | 
			
		||||
        k.stride(1), k.stride(2), k.stride(3), 
 | 
			
		||||
        v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
        T, K, V, BT, BK, BV
 | 
			
		||||
    )
 | 
			
		||||
    return w, u, A
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fwd_recompute_w_u(k, v, beta, A, BT):
 | 
			
		||||
    B, H, T, K, V = *k.shape, v.shape[-1]
 | 
			
		||||
    u = torch.empty_like(v)
 | 
			
		||||
    w = torch.empty_like(k)
 | 
			
		||||
    NT = triton.cdiv(T, BT)
 | 
			
		||||
    BK = min(triton.next_power_of_2(K), 64)
 | 
			
		||||
    BV = min(triton.next_power_of_2(V), 64)
 | 
			
		||||
    fwd_recompute_w_u_kernel[(NT, B*H)](
 | 
			
		||||
        k, v, beta, w, u, A,
 | 
			
		||||
        k.stride(1), k.stride(2), k.stride(3), 
 | 
			
		||||
        v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
        T, K, V, BT, BK, BV
 | 
			
		||||
    )
 | 
			
		||||
    return w, u
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):
 | 
			
		||||
    B, H, T, K, V = *k.shape, v.shape[-1]
 | 
			
		||||
 | 
			
		||||
    NT = triton.cdiv(T, BT)
 | 
			
		||||
    BK = min(triton.next_power_of_2(K), 64)
 | 
			
		||||
    BV = min(triton.next_power_of_2(V), 64)
 | 
			
		||||
    NT = triton.cdiv(T, BT)
 | 
			
		||||
    dk = torch.empty_like(k)
 | 
			
		||||
    dv = torch.empty_like(v).contiguous()
 | 
			
		||||
    dbeta = torch.zeros_like(beta)
 | 
			
		||||
 | 
			
		||||
    bwd_prepare_wy_repr_kernel[(NT, B*H)](
 | 
			
		||||
        k, v, beta, A,
 | 
			
		||||
        dw, du,  
 | 
			
		||||
        dk, dv, dbeta,
 | 
			
		||||
        k.stride(1), k.stride(2), k.stride(3), 
 | 
			
		||||
        v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
        T, K, V, BT, BK, BV
 | 
			
		||||
    )
 | 
			
		||||
    return dk, dv, dbeta
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WYRepresentationPrepration(torch.autograd.Function):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_fwd
 | 
			
		||||
    def forward(ctx, k, v, beta, chunk_size):
 | 
			
		||||
        ctx.BT = chunk_size
 | 
			
		||||
        w, u, A = fwd_prepare_wy_repr(k, v, beta,  ctx.BT)
 | 
			
		||||
        ctx.save_for_backward(k, v, beta, A)
 | 
			
		||||
        return w, u
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_bwd
 | 
			
		||||
    def backward(ctx, dw, du):
 | 
			
		||||
        k, v, beta, A = ctx.saved_tensors
 | 
			
		||||
        BT = ctx.BT
 | 
			
		||||
        dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT)
 | 
			
		||||
        return dk, dv, dbeta, None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
prepare_wy_repr = WYRepresentationPrepration.apply
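
A minimal call sketch for the fast path exported above (shapes mirror the self-test below; the chunk size of 64 is just an example and a CUDA device is required):

import torch

B, H, T, D = 2, 4, 256, 64
k = torch.nn.functional.normalize(torch.randn(B, H, T, D, device='cuda'), p=2, dim=-1)
v = torch.randn(B, H, T, D, device='cuda')
beta = torch.rand(B, H, T, device='cuda').sigmoid()

# w and u are the per-chunk WY factors computed by the Triton kernels above
w, u = prepare_wy_repr(k, v, beta, 64)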
 | 
			
		||||
 | 
			
		||||
def naive(k, v, beta, chunk_size):
 | 
			
		||||
    l_org = k.shape[2]
 | 
			
		||||
    l_new = triton.next_power_of_2(l_org)
 | 
			
		||||
    # pad k, v, beta
 | 
			
		||||
    k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
 | 
			
		||||
    v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
 | 
			
		||||
    beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
 | 
			
		||||
 | 
			
		||||
    k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
 | 
			
		||||
    # k = torch.nn.functional.normalize(k, dim=-1, p=2)
 | 
			
		||||
    beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
 | 
			
		||||
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
 | 
			
		||||
    k_beta = k * beta[..., None]
 | 
			
		||||
    v = v * beta[..., None]
 | 
			
		||||
    attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
 | 
			
		||||
    attn = attn * beta[..., None]
 | 
			
		||||
    x = attn @ v
 | 
			
		||||
 | 
			
		||||
    o = torch.zeros_like(k)
 | 
			
		||||
    o2 = torch.zeros_like(v)
 | 
			
		||||
 | 
			
		||||
    o[..., 0, :] = k_beta[..., 0, :].clone()
 | 
			
		||||
    o2[..., 0, :] = x[..., 0, :].clone()
 | 
			
		||||
    for i in range(1, chunk_size):
 | 
			
		||||
        o_i = (o[..., :i, :]).clone()
 | 
			
		||||
        o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
 | 
			
		||||
        o2_i = (o2[..., :i, :]).clone()
 | 
			
		||||
        o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
 | 
			
		||||
    return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    torch.set_default_dtype(torch.float32)
 | 
			
		||||
    seq_len = 1024
 | 
			
		||||
    b = 4
 | 
			
		||||
    h = 4
 | 
			
		||||
    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)
 | 
			
		||||
    v = torch.randn(b, h, seq_len, 128) 
 | 
			
		||||
    beta = torch.rand(b, h, seq_len).sigmoid()
 | 
			
		||||
    # beta = torch.ones(b, h, seq_len)
 | 
			
		||||
    require_grad = True
 | 
			
		||||
 | 
			
		||||
    k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
 | 
			
		||||
    do = torch.rand_like(k)
 | 
			
		||||
    do2 = torch.rand_like(v)
 | 
			
		||||
 | 
			
		||||
    o1, o2 = naive(k.clone(), v.clone(), beta.clone(), 64)
 | 
			
		||||
    if require_grad:
 | 
			
		||||
        o1.backward(do, retain_graph=True)
 | 
			
		||||
        o2.backward(do2, retain_graph=True)
 | 
			
		||||
 | 
			
		||||
        k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad
 | 
			
		||||
        k.grad = v.grad = beta.grad = None
 | 
			
		||||
 | 
			
		||||
    o3, o4 = prepare_wy_repr(k.clone(), v.clone(), beta.clone(), 64)
 | 
			
		||||
    print((o1-o3).abs().max())
 | 
			
		||||
    print((o2-o4).abs().max())
 | 
			
		||||
 | 
			
		||||
    if require_grad:
 | 
			
		||||
        o3.backward(do, retain_graph=True)
 | 
			
		||||
        o4.backward(do2, retain_graph=True)
 | 
			
		||||
        k_grad, v_grad, beta_grad = k.grad, v.grad, beta.grad
 | 
			
		||||
        print((k_grad2-k_grad).abs().max())
 | 
			
		||||
        print((v_grad2-v_grad).abs().max())
 | 
			
		||||
        print((beta_grad2-beta_grad).abs().max())
 | 
			
		||||
    breakpoint()
11 finetune/lora/v6/fla/ops/gla/__init__.py vendored Normal file
@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_gla
from .chunk_fuse import fused_chunk_gla
from .recurrent_fuse import fused_recurrent_gla

__all__ = [
    'chunk_gla',
    'fused_chunk_gla',
    'fused_recurrent_gla'
]
734 finetune/lora/v6/fla/ops/gla/chunk.py vendored Normal file
@ -0,0 +1,734 @@
# -*- coding: utf-8 -*-

# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from fla.ops.utils import chunk_reversed_cumsum_fwd
from fla.utils import contiguous

@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({'BS': 16}, num_warps=2),
 | 
			
		||||
        triton.Config({'BS': 16}, num_warps=4),
 | 
			
		||||
        triton.Config({'BS': 16}, num_warps=8),
 | 
			
		||||
        triton.Config({'BS': 32}, num_warps=2),
 | 
			
		||||
        triton.Config({'BS': 32}, num_warps=4),
 | 
			
		||||
        triton.Config({'BS': 32}, num_warps=8),
 | 
			
		||||
        triton.Config({'BS': 64}, num_warps=2),
 | 
			
		||||
        triton.Config({'BS': 64}, num_warps=4),
 | 
			
		||||
        triton.Config({'BS': 64}, num_warps=8),
 | 
			
		||||
    ],
 | 
			
		||||
    key=['S']
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_gla_fwd_kernel_cum(
 | 
			
		||||
    s,
 | 
			
		||||
    o,
 | 
			
		||||
    s_s_h,
 | 
			
		||||
    s_s_t,
 | 
			
		||||
    s_s_d,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    S: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BS: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    o_i = tl.arange(0, BT)
 | 
			
		||||
    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
 | 
			
		||||
 | 
			
		||||
    p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
 | 
			
		||||
    p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
 | 
			
		||||
    # [BT, BS]
 | 
			
		||||
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
 | 
			
		||||
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
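
The masked matmul in this kernel is a within-chunk inclusive cumulative sum; a small standalone sketch of the identity it relies on:

import torch

BT, S = 16, 8
b_s = torch.randn(BT, S)
# lower-triangular mask of ones, matching m_s above
m_s = (torch.arange(BT)[:, None] >= torch.arange(BT)[None, :]).float()
# row i of m_s @ b_s sums rows 0..i of b_s, i.e. an inclusive cumsum along dim 0
assert torch.allclose(m_s @ b_s, torch.cumsum(b_s, dim=0), atol=1e-5)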
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_gla_fwd_kernel_h(
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    g,
 | 
			
		||||
    h,
 | 
			
		||||
    h0,
 | 
			
		||||
    ht,
 | 
			
		||||
    s_k_h,
 | 
			
		||||
    s_k_t,
 | 
			
		||||
    s_k_d,
 | 
			
		||||
    s_v_h,
 | 
			
		||||
    s_v_t,
 | 
			
		||||
    s_v_d,
 | 
			
		||||
    s_h_h,
 | 
			
		||||
    s_h_t,
 | 
			
		||||
    s_h_d,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    V: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr,
 | 
			
		||||
    NT: tl.constexpr,
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,
 | 
			
		||||
    STORE_FINAL_STATE: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
    for i_t in range(NT):
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
 | 
			
		||||
        p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
 | 
			
		||||
        p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
 | 
			
		||||
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
        # [BK, BT]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        # [BK, BT]
 | 
			
		||||
        b_g = tl.load(p_g, boundary_check=(0, 1))
 | 
			
		||||
        if i_t < NT - 1:
 | 
			
		||||
            # [BK,]
 | 
			
		||||
            b_gn = tl.load(p_gn, boundary_check=(0,))
 | 
			
		||||
        else:
 | 
			
		||||
            b_gn = tl.min(b_g, axis=1)
 | 
			
		||||
        b_h *= tl.exp(b_gn)[:, None]
 | 
			
		||||
        b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
 | 
			
		||||
        b_h += tl.dot(b_k, b_v, allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
    if STORE_FINAL_STATE:
 | 
			
		||||
        p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
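
In the notation of this kernel (b_g holds the within-chunk cumulative log-decays per key dimension and b_gn their value at the chunk boundary), the loop maintains, roughly, the chunk-level state recurrence

$$
H_{n+1} \;=\; \mathrm{diag}\!\left(e^{g_n}\right) H_n \;+\; \sum_{t \in \text{chunk}\,n} \left(e^{\,g_n - g_t} \odot k_t\right) v_t^{\top}.
$$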
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_gla_fwd_kernel_intra(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    g,
 | 
			
		||||
    A,
 | 
			
		||||
    s_k_h,
 | 
			
		||||
    s_k_t,
 | 
			
		||||
    s_k_d,
 | 
			
		||||
    scale,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BC: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    NC: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
 | 
			
		||||
    n_bh = tl.num_programs(2)
 | 
			
		||||
 | 
			
		||||
    if i_i > i_j:
 | 
			
		||||
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
 | 
			
		||||
        p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
 | 
			
		||||
        p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
        p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
 | 
			
		||||
        # [BK,]
 | 
			
		||||
        b_gn = tl.load(p_gn, boundary_check=(0,))
 | 
			
		||||
        # [BC, BK]
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
        b_g = tl.load(p_g, boundary_check=(0, 1))
 | 
			
		||||
        b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype)
 | 
			
		||||
        # [BK, BC]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
 | 
			
		||||
        b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)
 | 
			
		||||
        # [BC, BC]
 | 
			
		||||
        b_A = tl.dot(b_qg, b_kg, allow_tf32=False)
 | 
			
		||||
        tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
    elif i_i == i_j:
 | 
			
		||||
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
        p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
        # [BC, BK]
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
        b_g = tl.load(p_g, boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
        o_i = tl.arange(0, BC)
 | 
			
		||||
        o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
 | 
			
		||||
        m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
 | 
			
		||||
        for j in range(0, BC):
 | 
			
		||||
            # [BK,]
 | 
			
		||||
            b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
 | 
			
		||||
            b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
 | 
			
		||||
            # [BC,]
 | 
			
		||||
            b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1)
 | 
			
		||||
            b_A = tl.where(o_i >= j, b_A, 0.)
 | 
			
		||||
            tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)
 | 
			
		||||
 | 
			
		||||
            p_k = tl.advance(p_k, (K,))
 | 
			
		||||
            p_gk = tl.advance(p_gk, (K,))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_gla_fwd_kernel_inter(
 | 
			
		||||
    q,
 | 
			
		||||
    v,
 | 
			
		||||
    g,
 | 
			
		||||
    h,
 | 
			
		||||
    o,
 | 
			
		||||
    A,
 | 
			
		||||
    s_k_h,
 | 
			
		||||
    s_k_t,
 | 
			
		||||
    s_k_d,
 | 
			
		||||
    s_v_h,
 | 
			
		||||
    s_v_t,
 | 
			
		||||
    s_v_d,
 | 
			
		||||
    s_h_h,
 | 
			
		||||
    s_h_t,
 | 
			
		||||
    s_h_d,
 | 
			
		||||
    scale,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    V: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
 | 
			
		||||
    for i_k in range(tl.cdiv(K, BK)):
 | 
			
		||||
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
 | 
			
		||||
        # [BT, BK]
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
        b_q = (b_q * scale).to(b_q.dtype)
 | 
			
		||||
        # [BT, BK]
 | 
			
		||||
        b_g = tl.load(p_g, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, BK]
 | 
			
		||||
        b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
 | 
			
		||||
        # [BK, BV]
 | 
			
		||||
        b_h = tl.load(p_h, boundary_check=(0, 1))
 | 
			
		||||
        # works but dkw, owing to divine benevolence
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        if i_k >= 0:
 | 
			
		||||
            b_o += tl.dot(b_qg, b_h, allow_tf32=False)
 | 
			
		||||
    p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
    p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
 | 
			
		||||
    # [BT, BV]
 | 
			
		||||
    b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
    # [BT, BT]
 | 
			
		||||
    b_A = tl.load(p_A, boundary_check=(0, 1))
 | 
			
		||||
    b_o += tl.dot(b_A, b_v, allow_tf32=False)
 | 
			
		||||
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_gla_bwd_kernel_dh(
 | 
			
		||||
    q,
 | 
			
		||||
    g,
 | 
			
		||||
    do,
 | 
			
		||||
    dh,
 | 
			
		||||
    s_k_h,
 | 
			
		||||
    s_k_t,
 | 
			
		||||
    s_k_d,
 | 
			
		||||
    s_v_h,
 | 
			
		||||
    s_v_t,
 | 
			
		||||
    s_v_d,
 | 
			
		||||
    s_h_h,
 | 
			
		||||
    s_h_t,
 | 
			
		||||
    s_h_d,
 | 
			
		||||
    scale,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    V: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr,
 | 
			
		||||
    NT: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
    for i_t in range(NT - 1, -1, -1):
 | 
			
		||||
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
 | 
			
		||||
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
 | 
			
		||||
        p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
 | 
			
		||||
        # [BK, BT]
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
        b_q = (b_q * scale).to(b_q.dtype)
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
        # [BK,]
 | 
			
		||||
        b_gn = tl.load(p_gn, boundary_check=(0,))
 | 
			
		||||
        # [BK, BV]
 | 
			
		||||
        b_dh *= tl.exp(b_gn)[:, None]
 | 
			
		||||
        # [BK, BT]
 | 
			
		||||
        b_g = tl.load(p_g, boundary_check=(0, 1))
 | 
			
		||||
        b_q = (b_q * tl.exp(b_g)).to(b_q.dtype)
 | 
			
		||||
 | 
			
		||||
        # [BK, BV]
 | 
			
		||||
        b_dh += tl.dot(b_q, b_do, allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_gla_bwd_kernel_inter(
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    h,
 | 
			
		||||
    g,
 | 
			
		||||
    A,
 | 
			
		||||
    do,
 | 
			
		||||
    dh,
 | 
			
		||||
    dq,
 | 
			
		||||
    dk,
 | 
			
		||||
    dv,
 | 
			
		||||
    dA,
 | 
			
		||||
    s_k_h,
 | 
			
		||||
    s_k_t,
 | 
			
		||||
    s_k_d,
 | 
			
		||||
    s_v_h,
 | 
			
		||||
    s_v_t,
 | 
			
		||||
    s_v_d,
 | 
			
		||||
    s_h_h,
 | 
			
		||||
    s_h_t,
 | 
			
		||||
    s_h_d,
 | 
			
		||||
    scale,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    V: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    BV: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    n_bh = tl.num_programs(2)
 | 
			
		||||
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
    p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
 | 
			
		||||
 | 
			
		||||
    # [BT, BK]
 | 
			
		||||
    b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
 | 
			
		||||
    b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk)
 | 
			
		||||
    b_k = (b_k * b_gn).to(b_k.dtype)
 | 
			
		||||
    # [BT, BT]
 | 
			
		||||
    b_A = tl.load(p_A, boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
 | 
			
		||||
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
 | 
			
		||||
    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
 | 
			
		||||
    for i_v in range(tl.cdiv(V, BV)):
 | 
			
		||||
        p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
 | 
			
		||||
        p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        # [BV, BK]
 | 
			
		||||
        b_h = tl.load(p_h, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1))
 | 
			
		||||
        # [BK, BV]
 | 
			
		||||
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
 | 
			
		||||
        if i_k == 0:
 | 
			
		||||
            b_dv += tl.dot(b_A, b_do, allow_tf32=False)
 | 
			
		||||
        b_do = (b_do * scale).to(b_do.dtype)
 | 
			
		||||
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
        # [BT, BT]
 | 
			
		||||
        b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
 | 
			
		||||
        # [BT, BK]
 | 
			
		||||
        b_dq += tl.dot(b_do, b_h, allow_tf32=False)
 | 
			
		||||
        # [BT, BK]
 | 
			
		||||
        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
 | 
			
		||||
    b_dq = b_dq * tl.exp(b_gk)
 | 
			
		||||
    b_dk = b_dk * b_gn
 | 
			
		||||
 | 
			
		||||
    p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
 | 
			
		||||
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
    o_i = tl.arange(0, BT)
 | 
			
		||||
    m_s = o_i[:, None] >= o_i[None, :]
 | 
			
		||||
    # [BT, BT]
 | 
			
		||||
    b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
 | 
			
		||||
    if i_k == 0:
 | 
			
		||||
        tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_gla_bwd_kernel_intra(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    g,
 | 
			
		||||
    dA,
 | 
			
		||||
    dq,
 | 
			
		||||
    dk,
 | 
			
		||||
    dg,
 | 
			
		||||
    s_k_h,
 | 
			
		||||
    s_k_t,
 | 
			
		||||
    s_k_d,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    K: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BC: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    NC: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    i_t, i_i = i_c // NC, i_c % NC
 | 
			
		||||
 | 
			
		||||
    p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
    p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
    # [BK,]
 | 
			
		||||
    b_gn = tl.load(p_gn, boundary_check=(0,))
 | 
			
		||||
    # [BC, BK]
 | 
			
		||||
    b_g = tl.load(p_g, boundary_check=(0, 1))
 | 
			
		||||
    b_dq = tl.zeros([BC, BK], dtype=tl.float32)
 | 
			
		||||
    for i_j in range(0, i_i):
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
        p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
        p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
 | 
			
		||||
        # [BC, BK]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
 | 
			
		||||
        b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
 | 
			
		||||
        # [BC, BC]
 | 
			
		||||
        b_dA = tl.load(p_dA, boundary_check=(0, 1))
 | 
			
		||||
        # [BC, BK]
 | 
			
		||||
        b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
 | 
			
		||||
    b_dq *= tl.exp(b_g - b_gn[None, :])
 | 
			
		||||
 | 
			
		||||
    o_i = tl.arange(0, BC)
 | 
			
		||||
    o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
 | 
			
		||||
    m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
 | 
			
		||||
    for j in range(0, BC):
 | 
			
		||||
        p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
        p_gkj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
        # [BC,]
 | 
			
		||||
        b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
 | 
			
		||||
        # [BK,]
 | 
			
		||||
        b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
 | 
			
		||||
        b_gkj = tl.load(p_gkj, boundary_check=(0,)).to(tl.float32)
 | 
			
		||||
        # [BC, BK]
 | 
			
		||||
        m_i = o_i[:, None] >= j
 | 
			
		||||
        # [BC, BK]
 | 
			
		||||
        b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_g - b_gkj[None, :]), 0.)
 | 
			
		||||
    p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
 | 
			
		||||
    b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1))
 | 
			
		||||
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
    p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
    p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
    # [BK,]
 | 
			
		||||
    b_gn = tl.load(p_gn, boundary_check=(0,))
 | 
			
		||||
    # [BC, BK]
 | 
			
		||||
    b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
 | 
			
		||||
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)
 | 
			
		||||
    for i_j in range(i_i + 1, NC):
 | 
			
		||||
        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
        p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
 | 
			
		||||
        # [BC, BK]
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
        b_g = tl.load(p_g, boundary_check=(0, 1))
 | 
			
		||||
        b_qg = (b_q * tl.exp(b_g - b_gn[None, :])).to(b_q.dtype)
 | 
			
		||||
        # [BC, BC]
 | 
			
		||||
        b_dA = tl.load(p_dA, boundary_check=(0, 1))
 | 
			
		||||
        # [BC, BK]
 | 
			
		||||
        b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False)
 | 
			
		||||
    b_dk *= tl.exp(b_gn[None, :] - b_gk)
 | 
			
		||||
 | 
			
		||||
    o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
 | 
			
		||||
    for j in range(0, BC):
 | 
			
		||||
        p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
        p_gqj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
 | 
			
		||||
        # [BC,]
 | 
			
		||||
        b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
 | 
			
		||||
        # [BK,]
 | 
			
		||||
        b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
 | 
			
		||||
        b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32)
 | 
			
		||||
        # [BC, BK]
 | 
			
		||||
        m_i = o_i[:, None] <= j
 | 
			
		||||
        b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.)
 | 
			
		||||
 | 
			
		||||
    p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
    p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
    p_dg = tl.make_block_ptr(dg + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
 | 
			
		||||
 | 
			
		||||
    b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
    b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1))
 | 
			
		||||
    b_dg = b_q * b_dq - b_k * b_dk
 | 
			
		||||
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChunkGLAFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level):
 | 
			
		||||
        B, H, T, K, V = *q.shape, v.shape[-1]
 | 
			
		||||
        BT, BC = 64, 16
 | 
			
		||||
        BK = min(64, triton.next_power_of_2(K))
 | 
			
		||||
        BV = min(64, triton.next_power_of_2(V))
 | 
			
		||||
        NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
 | 
			
		||||
        NK = triton.cdiv(K, BK)
 | 
			
		||||
        NV = triton.cdiv(V, BV)
 | 
			
		||||
        num_warps = 4 if BK == 64 else 2
 | 
			
		||||
        num_stages = 1
 | 
			
		||||
 | 
			
		||||
        def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
 | 
			
		||||
            NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
 | 
			
		||||
            h = q.new_empty(B, H, NT * K, V)
 | 
			
		||||
            grid = (NV, NK, B * H)
 | 
			
		||||
            chunk_gla_fwd_kernel_h[grid](
 | 
			
		||||
                k, v, g, h, h0, ht,
 | 
			
		||||
                k.stride(1), k.stride(2), k.stride(3),
 | 
			
		||||
                v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
                h.stride(1), h.stride(2), h.stride(3),
 | 
			
		||||
                T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
 | 
			
		||||
                USE_INITIAL_STATE=h0 is not None,
 | 
			
		||||
                STORE_FINAL_STATE=ht is not None,
 | 
			
		||||
                num_warps=num_warps,
 | 
			
		||||
                num_stages=num_stages
 | 
			
		||||
            )
 | 
			
		||||
            return h
 | 
			
		||||
 | 
			
		||||
        final_state = None
 | 
			
		||||
        if output_final_state:
 | 
			
		||||
            final_state = q.new_empty(B, H, K, V, dtype=torch.float)
 | 
			
		||||
 | 
			
		||||
        g_org, g = g, torch.empty_like(g, dtype=torch.float)
 | 
			
		||||
        def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
 | 
			
		||||
        # keep cumulative normalizer in fp32
 | 
			
		||||
        # this kernel is equivalent to
 | 
			
		||||
        # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
 | 
			
		||||
        chunk_gla_fwd_kernel_cum[grid](
 | 
			
		||||
            g_org, g,
 | 
			
		||||
            g.stride(1), g.stride(2), g.stride(3),
 | 
			
		||||
            T=T, S=K, BT=BT
 | 
			
		||||
        )
 | 
			
		||||
        h = fwd_inner(
 | 
			
		||||
            q=q, k=k, v=v, g=g,
 | 
			
		||||
            B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
 | 
			
		||||
            h0=initial_state if initial_state is not None else None,
 | 
			
		||||
            ht=final_state if final_state is not None else None
 | 
			
		||||
        )
 | 
			
		||||
        A = q.new_zeros(NK, B, H, T, BT)
 | 
			
		||||
        grid = (NK, NT * NC * NC, B * H)
 | 
			
		||||
        chunk_gla_fwd_kernel_intra[grid](
 | 
			
		||||
            q, k, g, A,
 | 
			
		||||
            k.stride(1), k.stride(2), k.stride(3),
 | 
			
		||||
            scale,
 | 
			
		||||
            T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages
 | 
			
		||||
        )
 | 
			
		||||
        A = A.sum(0, dtype=A.dtype)
 | 
			
		||||
        o = torch.empty_like(v)
 | 
			
		||||
        grid = (NV, NT, B * H)
 | 
			
		||||
        chunk_gla_fwd_kernel_inter[grid](
 | 
			
		||||
            q, v, g, h, o, A,
 | 
			
		||||
            k.stride(1), k.stride(2), k.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            h.stride(1), h.stride(2), h.stride(3),
 | 
			
		||||
            scale,
 | 
			
		||||
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages
 | 
			
		||||
        )
 | 
			
		||||
        if checkpoint_level >= 1:
 | 
			
		||||
            del g
 | 
			
		||||
            g = g_org
 | 
			
		||||
        if checkpoint_level > 1:
 | 
			
		||||
            del h
 | 
			
		||||
            h, initial_state = None, None
 | 
			
		||||
 | 
			
		||||
        ctx.save_for_backward(q, k, v, g, h, initial_state, A)
 | 
			
		||||
        ctx.BT = BT
 | 
			
		||||
        ctx.scale = scale
 | 
			
		||||
        ctx.checkpoint_level = checkpoint_level
 | 
			
		||||
        return o, final_state
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def backward(ctx, do, dht=None):
 | 
			
		||||
        q, k, v, g, h, initial_state, A = ctx.saved_tensors
 | 
			
		||||
        B, H, T, K, V = *q.shape, v.shape[-1]
 | 
			
		||||
        BT, BC = ctx.BT, 16
 | 
			
		||||
        BK = min(64, triton.next_power_of_2(K))
 | 
			
		||||
        BV = min(64, triton.next_power_of_2(V))
 | 
			
		||||
        NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
 | 
			
		||||
        NK = triton.cdiv(K, BK)
 | 
			
		||||
        num_warps = 4 if BK == 64 else 2
 | 
			
		||||
        num_stages = 1
 | 
			
		||||
 | 
			
		||||
        def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
 | 
			
		||||
            NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
 | 
			
		||||
            h = q.new_empty(B, H, NT * K, V)
 | 
			
		||||
            grid = (NV, NK, B * H)
 | 
			
		||||
            chunk_gla_fwd_kernel_h[grid](
 | 
			
		||||
                k, v, g, h, h0, ht,
 | 
			
		||||
                k.stride(1), k.stride(2), k.stride(3),
 | 
			
		||||
                v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
                h.stride(1), h.stride(2), h.stride(3),
 | 
			
		||||
                T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
 | 
			
		||||
                USE_INITIAL_STATE=h0 is not None,
 | 
			
		||||
                STORE_FINAL_STATE=ht is not None,
 | 
			
		||||
                num_warps=num_warps,
 | 
			
		||||
                num_stages=num_stages
 | 
			
		||||
            )
 | 
			
		||||
            return h
 | 
			
		||||
 | 
			
		||||
        def bwd_inner(q, g, do, B, H, T, K, V, BT, BK, BV, NT, scale):
 | 
			
		||||
            NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
 | 
			
		||||
            dh = q.new_empty(B, H, NT * K, V)
 | 
			
		||||
            grid = (NK, NV, B * H)
 | 
			
		||||
            chunk_gla_bwd_kernel_dh[grid](
 | 
			
		||||
                q, g, do, dh,
 | 
			
		||||
                q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
                do.stride(1), do.stride(2), do.stride(3),
 | 
			
		||||
                dh.stride(1), dh.stride(2), dh.stride(3),
 | 
			
		||||
                scale,
 | 
			
		||||
                T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
 | 
			
		||||
                num_warps=num_warps,
 | 
			
		||||
                num_stages=num_stages
 | 
			
		||||
            )
 | 
			
		||||
            return dh
 | 
			
		||||
 | 
			
		||||
        if ctx.checkpoint_level >= 1:
 | 
			
		||||
            # only the original g was saved to reduce memory; recompute its fp32 cumsum here during the backward pass
 | 
			
		||||
            g_org, g = g, torch.zeros_like(g, dtype=torch.float)
 | 
			
		||||
            def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
 | 
			
		||||
            # keep cumulative normalizer in fp32
 | 
			
		||||
            # this kernel is equivalent to
 | 
			
		||||
            # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
 | 
			
		||||
            chunk_gla_fwd_kernel_cum[grid](
 | 
			
		||||
                g_org, g,
 | 
			
		||||
                g.stride(1), g.stride(2), g.stride(3),
 | 
			
		||||
                T=T, S=K, BT=BT
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # rerun the forward pass to recompute h if checkpoint_level > 1
 | 
			
		||||
        if ctx.checkpoint_level > 1:
 | 
			
		||||
            h = fwd_inner(
 | 
			
		||||
                q=q, k=k, v=v, g=g,
 | 
			
		||||
                B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
 | 
			
		||||
                h0=initial_state if initial_state is not None else None,
 | 
			
		||||
                ht=None
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        scale = ctx.scale
 | 
			
		||||
        dh = bwd_inner(
 | 
			
		||||
            q, g, do,
 | 
			
		||||
            B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
 | 
			
		||||
            scale=scale
 | 
			
		||||
        )
 | 
			
		||||
        dq = torch.empty_like(q, dtype=torch.float)
 | 
			
		||||
        dk = torch.empty_like(k, dtype=torch.float)
 | 
			
		||||
        dg = torch.empty_like(k, dtype=torch.float)
 | 
			
		||||
        dv = v.new_empty(NK, *v.shape)
 | 
			
		||||
        dA = q.new_zeros(B, H, T, BT)
 | 
			
		||||
        grid = (NK, NT, B * H)
 | 
			
		||||
        chunk_gla_bwd_kernel_inter[grid](
 | 
			
		||||
            k, v, h, g, A, do, dh, dq, dk, dv, dA,
 | 
			
		||||
            k.stride(1), k.stride(2), k.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            h.stride(1), h.stride(2), h.stride(3),
 | 
			
		||||
            scale,
 | 
			
		||||
            T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages
 | 
			
		||||
        )
 | 
			
		||||
        dv = dv.sum(0, dtype=dv.dtype)
 | 
			
		||||
        grid = (NK, NT * NC, B * H)
 | 
			
		||||
        chunk_gla_bwd_kernel_intra[grid](
 | 
			
		||||
            q, k, g, dA, dq, dk, dg,
 | 
			
		||||
            k.stride(1), k.stride(2), k.stride(3),
 | 
			
		||||
            T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        dq = dq.to(q.dtype)
 | 
			
		||||
        dk = dk.to(q.dtype)
 | 
			
		||||
        # reversed cumsum, equivalent to:
 | 
			
		||||
        #
 | 
			
		||||
        # def reversed_cumsum(x, dim=-1):
 | 
			
		||||
        #     c = x.cumsum(dim)
 | 
			
		||||
        #     return x + c.index_select(dim, x.new_tensor([c.shape[dim]-1], dtype=torch.long)) - c
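        # e.g. for x = [1, 2, 3] along the last dim: c = [1, 3, 6], so
        # x + c[-1] - c = [1+6-1, 2+6-3, 3+6-6] = [6, 5, 3], i.e. the suffix (reversed) cumsum.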
 | 
			
		||||
        dg = chunk_reversed_cumsum_fwd(dg).to(k.dtype)
 | 
			
		||||
        return dq, dk, dv, dg, None, None, None, None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chunk_gla(
 | 
			
		||||
    q: torch.Tensor,
 | 
			
		||||
    k: torch.Tensor,
 | 
			
		||||
    v: torch.Tensor,
 | 
			
		||||
    g: torch.Tensor,
 | 
			
		||||
    scale: Optional[int] = None,
 | 
			
		||||
    initial_state: torch.Tensor = None,
 | 
			
		||||
    output_final_state: bool = False,
 | 
			
		||||
    checkpoint_level: Optional[int] = 2
 | 
			
		||||
) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Args:
 | 
			
		||||
        q (torch.Tensor):
 | 
			
		||||
            queries of shape `(B, H, T, K)`
 | 
			
		||||
        k (torch.Tensor):
 | 
			
		||||
            keys of shape `(B, H, T, K)`
 | 
			
		||||
        v (torch.Tensor):
 | 
			
		||||
            values of shape `(B, H, T, V)`
 | 
			
		||||
        g (torch.Tensor):
 | 
			
		||||
            Forget gates of shape `(B, H, T, K)` applied to keys.
 | 
			
		||||
        scale (Optional[int]):
 | 
			
		||||
            Scale factor for the GLA attention scores.
 | 
			
		||||
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
 | 
			
		||||
        initial_state (Optional[torch.Tensor]):
 | 
			
		||||
            Initial state of shape `(B, H, K, V)`. Default: `None`.
 | 
			
		||||
        output_final_state (Optional[bool]):
 | 
			
		||||
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
 | 
			
		||||
        checkpoint_level (Optional[int]):
 | 
			
		||||
            Checkpointing level; higher values save more memory but require more recomputation during the backward pass.
 | 
			
		||||
            Default: `2`:
 | 
			
		||||
            - Level `0`: no memory saved, no recomputation.
 | 
			
		||||
            - Level `1`: recompute the fp32 cumulative values during backward.
 | 
			
		||||
            - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
 | 
			
		||||
    """
 | 
			
		||||
    assert checkpoint_level in [0, 1, 2]
 | 
			
		||||
    if scale is None:
 | 
			
		||||
        scale = q.shape[-1] ** -0.5
 | 
			
		||||
    if initial_state is not None:
 | 
			
		||||
        initial_state = initial_state.detach()
 | 
			
		||||
    o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level)
 | 
			
		||||
    return o, final_state
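
# A minimal usage sketch (not part of the library): it assumes this module is importable,
# that a CUDA device is available, and that shapes follow the docstring above.
if __name__ == "__main__":
    import torch.nn.functional as F
    B, H, T, K, V = 2, 4, 128, 64, 64
    q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
    # log-space forget gates; logsigmoid keeps them negative, as the kernels expect
    g = F.logsigmoid(torch.randn(B, H, T, K, device='cuda')).to(torch.bfloat16)
    o, final_state = chunk_gla(q, k, v, g, output_final_state=True, checkpoint_level=2)
    print(o.shape, final_state.shape)  # (B, H, T, V) and (B, H, K, V)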
 | 
			
		||||
							
								
								
									
548  finetune/lora/v6/fla/ops/gla/chunk_fuse.py  vendored  Normal file
@ -0,0 +1,548 @@
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
# Copyright (c) 2023, Songlin Yang
 | 
			
		||||
# Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635
 | 
			
		||||
# on-the-fly computation without materializing hidden states into HBM
 | 
			
		||||
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
from packaging import version
 | 
			
		||||
from torch.cuda.amp import custom_bwd, custom_fwd
 | 
			
		||||
 | 
			
		||||
from fla.ops.gla.chunk_util import (bwd_decay_global_cumsum, fwd_decay_cumsum,
 | 
			
		||||
                                    prepare_qg_kg)
 | 
			
		||||
from fla.utils import contiguous
 | 
			
		||||
 | 
			
		||||
inv_ln2 = 1.44269504  # = 1 / ln(2); converts natural-log decays into base-2 exponents for tl.math.exp2
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_chunk_gla_fwd_kernel(
 | 
			
		||||
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
 | 
			
		||||
    q,  # query [B, H, L, D_head_K]
 | 
			
		||||
    k,  # key [B, H, L, D_head_K]
 | 
			
		||||
    v,  # value [B, H, L, D_head_V]
 | 
			
		||||
    g,  # cumulative sum of log decay [B, H, L, D_head_K]
 | 
			
		||||
    o,  # output [B, H, L, D_head_V]
 | 
			
		||||
 | 
			
		||||
    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
 | 
			
		||||
    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
 | 
			
		||||
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    s_vo_h,  # stride size: L * D_head_V
 | 
			
		||||
    s_vo_t,  # stride size: D_head_V
 | 
			
		||||
    s_vo_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    B,  # batch size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
    DV: tl.constexpr,  # D_head_V
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,
 | 
			
		||||
    STORE_FINAL_STATE: tl.constexpr,
 | 
			
		||||
    CHECK: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    # indices
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    # make block pointers
 | 
			
		||||
    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
 | 
			
		||||
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
    p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
    
 | 
			
		||||
    mask = (i_k * BK + tl.arange(0, BK)) < DK
 | 
			
		||||
 | 
			
		||||
    for i in range(0, tl.cdiv(T, BT)):
 | 
			
		||||
        # [BK, BT]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, BV]
 | 
			
		||||
        b_o = tl.zeros([BT, BV], dtype=tl.float32)
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, BK]
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
        d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        if CHECK and i == 0:
 | 
			
		||||
            b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
 | 
			
		||||
            b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
 | 
			
		||||
        else:
 | 
			
		||||
            b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
 | 
			
		||||
            b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
        p_q = tl.advance(p_q, (BT, 0))
 | 
			
		||||
        p_k = tl.advance(p_k, (0, BT))
 | 
			
		||||
        p_v = tl.advance(p_v, (BT, 0))
 | 
			
		||||
        p_o = tl.advance(p_o, (BT, 0))
 | 
			
		||||
        p_db += BT * DK
 | 
			
		||||
 | 
			
		||||
    if STORE_FINAL_STATE:
 | 
			
		||||
        p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
 | 
			
		||||
        tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
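
# A hedged PyTorch sketch of the inter-chunk recurrence the kernel above implements; the helper
# below is illustrative only and not part of this file. `qg`/`kg` are the decay-preconditioned
# tensors from `prepare_qg_kg`, `g_cum` is the per-chunk cumulative log2 decay produced by
# `fwd_decay_cumsum`, and T is assumed to be a multiple of BT (the wrapper pads to ensure this).
def fused_chunk_gla_fwd_inter_ref(qg, kg, v, g_cum, BT=16, h0=None):
    B, H, T, DK = qg.shape
    DV = v.shape[-1]
    h = torch.zeros(B, H, DK, DV, dtype=torch.float32, device=qg.device) if h0 is None else h0.float()
    o = torch.zeros(B, H, T, DV, dtype=torch.float32, device=qg.device)
    d_b = g_cum.float().view(B, H, T // BT, BT, DK)[..., -1, :]  # cumulative decay at each chunk end
    for i in range(T // BT):
        sl = slice(i * BT, (i + 1) * BT)
        o[:, :, sl] = qg[:, :, sl].float() @ h                   # contribution of all previous chunks
        h = h * torch.exp2(d_b[:, :, i]).unsqueeze(-1) + kg[:, :, sl].float().transpose(-1, -2) @ v[:, :, sl].float()
    return o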
 | 
			
		||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_chunk_gla_bwd_kernel(
 | 
			
		||||
    q, k, v, g,
 | 
			
		||||
    do,  # gradient of output [B, H, L, D_head_V]
 | 
			
		||||
    dq,  # gradient of query [NV, B, H, L, D_head_K]
 | 
			
		||||
    dk,  # gradient of key [NV, B, H, L, D_head_K]
 | 
			
		||||
    dv,  # gradient of value [NK, B, H, L, D_head_V]
 | 
			
		||||
 | 
			
		||||
    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
 | 
			
		||||
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    s_vo_h,  # stride size: L * D_head_V
 | 
			
		||||
    s_vo_t,  # stride size: D_head_V
 | 
			
		||||
    s_vo_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    B,  # batch_size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    # clamp_min,  # minimum log value of the gate for numerical stability. default: -5
 | 
			
		||||
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
    DV: tl.constexpr,  # D_head_V
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,
 | 
			
		||||
    CHECK: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    # [BV, BK]
 | 
			
		||||
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
 | 
			
		||||
        b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
    
 | 
			
		||||
    mask = (i_k * BK + tl.arange(0, BK)) < DK    
 | 
			
		||||
    for i in range(0, tl.cdiv(T, BT)):
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
 | 
			
		||||
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        b_dq = tl.zeros([BT, BK], dtype=tl.float32)
 | 
			
		||||
        # [BT, DK]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # b_g = tl.load(p_g, boundary_check=(0, 1)) * inv_ln2
 | 
			
		||||
        d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
        # [DV, BT]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, DV]
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1))
 | 
			
		||||
        # [DV, DK]
 | 
			
		||||
        if CHECK and i == 0:
 | 
			
		||||
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
 | 
			
		||||
            b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
 | 
			
		||||
        else:
 | 
			
		||||
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
 | 
			
		||||
            b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
 | 
			
		||||
        b_dq *= scale
 | 
			
		||||
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
    # sync threads
 | 
			
		||||
    b_h = None
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
    # [BK, BV]
 | 
			
		||||
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    # cum = tl.zeros([BK], dtype=tl.float32)
 | 
			
		||||
    for i in range(1, tl.cdiv(T, BT) + 1):
 | 
			
		||||
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
 | 
			
		||||
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK),
 | 
			
		||||
                                 (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
        p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV),
 | 
			
		||||
                                 (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
 | 
			
		||||
        # [DK, BT]
 | 
			
		||||
        b_q = tl.load(p_q, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, DK]
 | 
			
		||||
        b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
        # [BT, DV]
 | 
			
		||||
        b_v = tl.load(p_v, boundary_check=(0, 1))
 | 
			
		||||
        b_do = tl.load(p_do, boundary_check=(0, 1))
 | 
			
		||||
        b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
        # inter-chunk
 | 
			
		||||
        # [DK, DV]
 | 
			
		||||
        if CHECK and i == 1:
 | 
			
		||||
            b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
 | 
			
		||||
            b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
 | 
			
		||||
            b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
 | 
			
		||||
        else:
 | 
			
		||||
            b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
 | 
			
		||||
            b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
 | 
			
		||||
            b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
 | 
			
		||||
 | 
			
		||||
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
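
# A hedged PyTorch sketch of the first loop in the kernel above (illustrative only, not part of
# this file): the inter-chunk part of dq, computed from the running state h while replaying the
# chunks in order. `kg` is the decay-preconditioned key from `prepare_qg_kg`, `g_cum` the per-chunk
# cumulative log2 decay, and T is assumed to be a multiple of BT; the dk/dv loop runs analogously
# over the chunks in reverse order.
def fused_chunk_gla_bwd_dq_ref(kg, v, g_cum, do, scale, BT=16, h0=None):
    B, H, T, DK = kg.shape
    DV = v.shape[-1]
    h = torch.zeros(B, H, DK, DV, dtype=torch.float32, device=kg.device) if h0 is None else h0.float()
    dq = torch.zeros(B, H, T, DK, dtype=torch.float32, device=kg.device)
    d_b = g_cum.float().view(B, H, T // BT, BT, DK)[..., -1, :]
    for i in range(T // BT):
        sl = slice(i * BT, (i + 1) * BT)
        dq[:, :, sl] = scale * (do[:, :, sl].float() @ h.transpose(-1, -2))  # [BT, DV] @ [DV, DK]
        h = h * torch.exp2(d_b[:, :, i]).unsqueeze(-1) + kg[:, :, sl].float().transpose(-1, -2) @ v[:, :, sl].float()
    return dq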
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fwd_inner_chunk(
 | 
			
		||||
    q, k, g, A,
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
    B,  # batch_size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    # clamp_min,  # minimum log value of the gate for numerical stability. default: -5
 | 
			
		||||
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
 | 
			
		||||
    b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
    p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
 | 
			
		||||
    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
    mask = (i_k * BK + tl.arange(0, BK)) < DK
 | 
			
		||||
    o_i = tl.arange(0, BT)
 | 
			
		||||
 | 
			
		||||
    p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
 | 
			
		||||
    p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
 | 
			
		||||
    p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
 | 
			
		||||
 | 
			
		||||
    for i in range(BT):
 | 
			
		||||
        _q = tl.load(p_q, mask=mask, other=0) * scale
 | 
			
		||||
        gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        s = _q[None, :] * b_k * tl.math.exp2(gq[None, :] - b_g)
 | 
			
		||||
        score = tl.sum(s, axis=1)
 | 
			
		||||
        score = tl.where(o_i <= i, score, 0)
 | 
			
		||||
        tl.store(p_A, score.to(p_A.dtype.element_ty))
 | 
			
		||||
        p_q += DK
 | 
			
		||||
        p_gq += DK
 | 
			
		||||
        p_A += BT
 | 
			
		||||
 | 
			
		||||
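
# A hedged PyTorch sketch of the intra-chunk score matrix A built by the kernel above (the helper
# is illustrative only): A[..., i, j] = sum_d q_i * k_j * 2^(g_i - g_j) for j <= i within each chunk,
# with `g_cum` the per-chunk cumulative log2 decay and T assumed to be a multiple of BT.
def fwd_inner_chunk_ref(q, k, g_cum, scale, BT=16):
    B, H, T, DK = q.shape
    qc = q.float().view(B, H, T // BT, BT, DK) * scale
    kc = k.float().view(B, H, T // BT, BT, DK)
    gc = g_cum.float().view(B, H, T // BT, BT, DK)
    A = q.new_zeros(B, H, T // BT, BT, BT, dtype=torch.float32)
    idx = torch.arange(BT, device=q.device)
    for i in range(BT):
        s = (qc[..., i:i + 1, :] * kc * torch.exp2(gc[..., i:i + 1, :] - gc)).sum(-1)  # [B, H, NT, BT]
        A[..., i, :] = torch.where(idx <= i, s, torch.zeros_like(s))
    return A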
 | 
			
		||||
@triton.jit
 | 
			
		||||
def bwd_inner_chunk(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    g,
 | 
			
		||||
    dA,
 | 
			
		||||
    dq,
 | 
			
		||||
    dk,
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
    B,  # batch_size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    # clamp_min,  # minimum log value of the gate for numerical stability. default: -5
 | 
			
		||||
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
):
 | 
			
		||||
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    b_k = tl.load(p_k, boundary_check=(0, 1))
 | 
			
		||||
    p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
    mask = (i_k * BK + tl.arange(0, BK)) < DK
 | 
			
		||||
    o_i = tl.arange(0, BT)
 | 
			
		||||
 | 
			
		||||
    p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
 | 
			
		||||
    p_dq = dq + (i_bh) * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
 | 
			
		||||
    p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
 | 
			
		||||
    p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
 | 
			
		||||
 | 
			
		||||
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    for i in range(BT):
 | 
			
		||||
        _q = tl.load(p_q, mask=mask, other=0)
 | 
			
		||||
        gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        score = tl.math.exp2(gq[None, :] - b_g)
 | 
			
		||||
        score = tl.where(o_i[:, None] <= i, score, 0)
 | 
			
		||||
        _dA = tl.load(p_dA)
 | 
			
		||||
        _dA = tl.where(o_i <= i, _dA, 0)
 | 
			
		||||
        b_dk += (_dA[:, None] * score * _q[None, :])
 | 
			
		||||
        b_dq = tl.sum(_dA[:, None] * score * b_k, axis=0)
 | 
			
		||||
        tl.store(p_dq, b_dq, mask=mask)
 | 
			
		||||
        p_q += DK
 | 
			
		||||
        p_dq += DK
 | 
			
		||||
        p_gq += DK
 | 
			
		||||
        p_dA += BT
 | 
			
		||||
 | 
			
		||||
    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
 | 
			
		||||
    tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedChunkGLAFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_fwd
 | 
			
		||||
    def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
 | 
			
		||||
        ctx.g_dtype = g.dtype
 | 
			
		||||
        g_original = g
 | 
			
		||||
        # cumulative decay should be in float32, otherwise the error will accumulate and be amplified.
 | 
			
		||||
        g = torch.empty_like(g, dtype=torch.float32)
 | 
			
		||||
        batch_size, n_heads, seq_len, d_head_qk = q.shape
 | 
			
		||||
        d_head_v = v.shape[-1]
 | 
			
		||||
        ctx.scale = scale
 | 
			
		||||
 | 
			
		||||
        # inter-chunk
 | 
			
		||||
        BT = 16  # chunk_size
 | 
			
		||||
        BK, BV = min(d_head_qk, 64), min(d_head_v, 64)
 | 
			
		||||
        num_stages = 1
 | 
			
		||||
        num_warps = 2
 | 
			
		||||
 | 
			
		||||
        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
 | 
			
		||||
        o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
 | 
			
		||||
        q_g = torch.empty_like(q)
 | 
			
		||||
        k_g = torch.empty_like(k)
 | 
			
		||||
        grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
 | 
			
		||||
        fwd_decay_cumsum[grid](
 | 
			
		||||
            g_original,
 | 
			
		||||
            g,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            BT=BT, BK=BK, DK=d_head_qk, num_warps=1
 | 
			
		||||
        )
 | 
			
		||||
        prepare_qg_kg[grid](
 | 
			
		||||
            q, k, g, q_g, k_g,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            BT=BT, BK=BK, DK=d_head_qk, num_warps=1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if output_final_state:
 | 
			
		||||
            final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float, requires_grad=False)
 | 
			
		||||
        else:
 | 
			
		||||
            final_state = None
 | 
			
		||||
        # the bug still exists even for Triton 2.2 on H100 GPUs
 | 
			
		||||
        # so we always enable initial checks
 | 
			
		||||
        CHECK = True
 | 
			
		||||
        if version.parse(triton.__version__) < version.parse('2.2.0'):
 | 
			
		||||
            import warnings
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                "Triton<2.2.0 detected for running this kernel, "
 | 
			
		||||
                "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
 | 
			
		||||
                "that lead to significant precision loss. "
 | 
			
		||||
                "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
 | 
			
		||||
                "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
 | 
			
		||||
            )
 | 
			
		||||
            CHECK = True
 | 
			
		||||
 | 
			
		||||
        grid = (NV, NK, batch_size * n_heads)
 | 
			
		||||
        fused_chunk_gla_fwd_kernel[grid](
 | 
			
		||||
            q_g, k_g, v, g, o, initial_state, final_state,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
 | 
			
		||||
            USE_INITIAL_STATE=initial_state is not None,
 | 
			
		||||
            STORE_FINAL_STATE=output_final_state,
 | 
			
		||||
            CHECK=CHECK,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        o = o.sum(0)
 | 
			
		||||
 | 
			
		||||
        # intra-chunk
 | 
			
		||||
        chunk_size = 16
 | 
			
		||||
        num_chunk = seq_len // chunk_size
 | 
			
		||||
        v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
 | 
			
		||||
        BK = min(d_head_qk, 64)
 | 
			
		||||
        NK = triton.cdiv(d_head_qk, BK)
 | 
			
		||||
        A = q.new_empty(NK, batch_size, n_heads, triton.cdiv(seq_len, BT), BT, BT)
 | 
			
		||||
        grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
 | 
			
		||||
        fwd_inner_chunk[grid](
 | 
			
		||||
            q, k, g, A,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,  BT=BT, BK=BK, DK=d_head_qk, num_stages=3,
 | 
			
		||||
            num_warps=4
 | 
			
		||||
        )
 | 
			
		||||
        A = A.sum(0)
 | 
			
		||||
        o2 = A @ v2
 | 
			
		||||
        o2 = rearrange(o2, 'b h n c d -> b h (n c) d')
 | 
			
		||||
        # combine inner and inter
 | 
			
		||||
        o.add_(o2)
 | 
			
		||||
        ctx.save_for_backward(q, k, v, g_original, A, initial_state)
 | 
			
		||||
        ctx.CHECK = CHECK
 | 
			
		||||
        return o.to(v), final_state
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_bwd
 | 
			
		||||
    def backward(ctx, do, d_final_state=None):
 | 
			
		||||
        q, k, v, g_origin, A, initial_state = ctx.saved_tensors
 | 
			
		||||
        batch_size, n_heads, seq_len, d_head_qk = q.shape
 | 
			
		||||
        d_head_v = v.shape[-1]
 | 
			
		||||
        scale = ctx.scale
 | 
			
		||||
 | 
			
		||||
        # recomputation
 | 
			
		||||
        # inter-chunk
 | 
			
		||||
        BT = 16  # chunk_size
 | 
			
		||||
        g = torch.empty_like(g_origin, dtype=torch.float32)
 | 
			
		||||
        BK, BV = min(d_head_qk, 64), min(d_head_v, 64)
 | 
			
		||||
        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
 | 
			
		||||
        q_g = torch.empty_like(q)
 | 
			
		||||
        k_g = torch.empty_like(k)
 | 
			
		||||
        grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
 | 
			
		||||
        fwd_decay_cumsum[grid](
 | 
			
		||||
            g_origin,
 | 
			
		||||
            g,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            BT=BT, BK=BK, DK=d_head_qk, num_warps=1
 | 
			
		||||
        )
 | 
			
		||||
        prepare_qg_kg[grid](
 | 
			
		||||
            q, k, g, q_g, k_g,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            BT=BT, BK=BK, DK=d_head_qk, num_warps=1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # inter-chunk
 | 
			
		||||
        BT = 16
 | 
			
		||||
        BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
 | 
			
		||||
        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
 | 
			
		||||
        num_stages = 1
 | 
			
		||||
        num_warps = 2
 | 
			
		||||
        dq = q.new_empty(NV, batch_size, n_heads,  seq_len, d_head_qk)
 | 
			
		||||
        dk = q.new_empty(NV, batch_size, n_heads,  seq_len, d_head_qk)
 | 
			
		||||
        dv = q.new_empty(NK, batch_size, n_heads,  seq_len, d_head_v)
 | 
			
		||||
 | 
			
		||||
        grid = (NV, NK, batch_size * n_heads)
 | 
			
		||||
 | 
			
		||||
        fused_chunk_gla_bwd_kernel[grid](
 | 
			
		||||
            q_g, k_g, v, g, do, dq, dk, dv, initial_state,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            # clamp_min=-3,
 | 
			
		||||
            BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
 | 
			
		||||
            USE_INITIAL_STATE=initial_state is not None,
 | 
			
		||||
            CHECK=ctx.CHECK,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages,
 | 
			
		||||
        )
 | 
			
		||||
        dq = dq.sum(0)
 | 
			
		||||
        dk = dk.sum(0)
 | 
			
		||||
        dv = dv.sum(0)
 | 
			
		||||
 | 
			
		||||
        # intra chunk
 | 
			
		||||
        num_chunk = seq_len // BT
 | 
			
		||||
        v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
 | 
			
		||||
        do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk)
 | 
			
		||||
        dA2 = (do2 @ v2.transpose(-2, -1)) * scale
 | 
			
		||||
        dv2 = A.transpose(-1, -2) @ do2
 | 
			
		||||
        dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk)
 | 
			
		||||
 | 
			
		||||
        BK = min(triton.next_power_of_2(d_head_qk), 16)
 | 
			
		||||
        NK = triton.cdiv(d_head_qk, BK)
 | 
			
		||||
        dk2 = torch.empty_like(k)
 | 
			
		||||
        dq2 = torch.empty_like(q)
 | 
			
		||||
 | 
			
		||||
        grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
 | 
			
		||||
        bwd_inner_chunk[grid](
 | 
			
		||||
            q, k, g,
 | 
			
		||||
            dA2, dq2, dk2,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            BT=BT, DK=d_head_qk, BK=BK,
 | 
			
		||||
            num_warps=1,
 | 
			
		||||
            num_stages=3
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        BK = min(triton.next_power_of_2(d_head_qk), 32)
 | 
			
		||||
        NK = triton.cdiv(d_head_qk, BK)
 | 
			
		||||
        dg = torch.empty_like(g, dtype=torch.float32)
 | 
			
		||||
        grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
 | 
			
		||||
        bwd_decay_global_cumsum[grid](
 | 
			
		||||
            dq2, dq, dk2, dk, q, k, g, dg,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            BT=BT, DK=d_head_qk, BK=BK,
 | 
			
		||||
            num_warps=1,
 | 
			
		||||
            num_stages=1
 | 
			
		||||
        )
 | 
			
		||||
        dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)
 | 
			
		||||
 | 
			
		||||
        def rev_cumsum_exclusive(x):
 | 
			
		||||
            cumsum_x = x.cumsum(-2)
 | 
			
		||||
            rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x
 | 
			
		||||
            return rev_cumsum_x
 | 
			
		||||
 | 
			
		||||
        rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])
 | 
			
		||||
        dg.add_(rev_cumsum_dg.unsqueeze(-2))
 | 
			
		||||
        dv.add_(dv2)
 | 
			
		||||
        dg = rearrange(dg, 'b h n c d -> b h (n c) d')
 | 
			
		||||
 | 
			
		||||
        return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pad(x, chunk_size=16):
 | 
			
		||||
    seq_len = x.shape[-2]
 | 
			
		||||
    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size
 | 
			
		||||
    if x.shape[-2] % chunk_size != 0:
 | 
			
		||||
        x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len))
 | 
			
		||||
    
 | 
			
		||||
    return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ceildiv(a, b):
 | 
			
		||||
    return -(a // -b)
 | 
			
		||||
 | 
			
		||||
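# e.g. ceildiv(100, 16) == 7, so pad() zero-pads a length-100 sequence to 112 time steps;
# fused_chunk_gla below runs the kernels on the padded tensors and slices the output back
# to the original sequence length.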
 | 
			
		||||
def fused_chunk_gla(
 | 
			
		||||
    q: torch.Tensor,
 | 
			
		||||
    k: torch.Tensor,
 | 
			
		||||
    v: torch.Tensor,
 | 
			
		||||
    g: torch.Tensor,
 | 
			
		||||
    scale: int = -1,
 | 
			
		||||
    initial_state: torch.Tensor = None,
 | 
			
		||||
    output_final_state: bool = False
 | 
			
		||||
) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    if scale == -1:
 | 
			
		||||
        scale = q.shape[-1] ** -0.5
 | 
			
		||||
    if initial_state is not None:
 | 
			
		||||
        initial_state = initial_state.detach()
 | 
			
		||||
    seq_len = q.shape[-2]
 | 
			
		||||
    q, k, v, g = map(lambda x: pad(x), [q, k, v, g])
 | 
			
		||||
    o, final_state = FusedChunkGLAFunction.apply(
 | 
			
		||||
        q, k, v, g, scale, initial_state, output_final_state)
 | 
			
		||||
    o = o[..., :seq_len, :]
 | 
			
		||||
    return o, final_state
 | 
			
		||||
							
								
								
									
138  finetune/lora/v6/fla/ops/gla/chunk_util.py  vendored  Normal file
@ -0,0 +1,138 @@
 | 
			
		||||
import triton
 | 
			
		||||
import triton.language as tl
 | 
			
		||||
 | 
			
		||||
inv_ln2 = 1.44269504  # = 1 / ln(2); converts natural-log decays into base-2 exponents for tl.math.exp2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fwd_decay_cumsum(
 | 
			
		||||
    g,
 | 
			
		||||
    g_o, 
 | 
			
		||||
    s_qk_h,
 | 
			
		||||
    s_qk_t,
 | 
			
		||||
    s_qk_d,
 | 
			
		||||
    B,
 | 
			
		||||
    H,
 | 
			
		||||
    T,
 | 
			
		||||
    scale,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    DK: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    cum_decay = tl.zeros([BK], dtype=tl.float32)
 | 
			
		||||
    mask = (i_k * BK + tl.arange(0, BK)) < DK
 | 
			
		||||
 | 
			
		||||
    for i in range(BT):
 | 
			
		||||
        _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        cum_decay += _g * inv_ln2
 | 
			
		||||
        tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)
 | 
			
		||||
        p_g += DK
 | 
			
		||||
        p_go += DK
 | 
			
		||||
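
# A hedged PyTorch sketch of what the kernel above computes (illustrative only): a per-chunk
# cumulative sum of the log-decay, converted to base 2 (multiplied by 1/ln 2) so the other
# kernels can use tl.math.exp2. T is assumed to be a multiple of BT.
import torch

def fwd_decay_cumsum_ref(g, BT=16):
    B, H, T, DK = g.shape
    return (g.float() * inv_ln2).view(B, H, T // BT, BT, DK).cumsum(-2).view(B, H, T, DK)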
 | 
			
		||||
@triton.jit
 | 
			
		||||
def prepare_qg_kg(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    g,
 | 
			
		||||
    qg,
 | 
			
		||||
    kg,
 | 
			
		||||
    s_qk_h,
 | 
			
		||||
    s_qk_t,
 | 
			
		||||
    s_qk_d,
 | 
			
		||||
    B,
 | 
			
		||||
    H,
 | 
			
		||||
    T,
 | 
			
		||||
    scale,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    DK: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
 | 
			
		||||
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
 | 
			
		||||
    
 | 
			
		||||
    mask = (i_k * BK + tl.arange(0, BK)) < DK
 | 
			
		||||
 | 
			
		||||
    last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))
 | 
			
		||||
 | 
			
		||||
    for i in range(BT):
 | 
			
		||||
        _q = tl.load(p_q, mask=mask, other=0)
 | 
			
		||||
        _k = tl.load(p_k, mask=mask, other=0)
 | 
			
		||||
        _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        _q *= tl.math.exp2(_g) * scale
 | 
			
		||||
        _k *= tl.math.exp2(last_decay - _g)
 | 
			
		||||
        tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)
 | 
			
		||||
        tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)
 | 
			
		||||
        p_q += DK
 | 
			
		||||
        p_g += DK
 | 
			
		||||
        p_k += DK
 | 
			
		||||
        p_kg += DK
 | 
			
		||||
        p_qg += DK
 | 
			
		||||
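
# A hedged PyTorch sketch of the kernel above (illustrative only): q is pre-scaled by its own
# cumulative decay (times `scale`), while k is scaled by the decay remaining until the end of
# its chunk. `g_cum` is the output of fwd_decay_cumsum and T is assumed to be a multiple of BT.
import torch

def prepare_qg_kg_ref(q, k, g_cum, scale, BT=16):
    B, H, T, DK = q.shape
    gc = g_cum.float().view(B, H, T // BT, BT, DK)
    last = gc[..., -1:, :]                                   # cumulative decay at each chunk end
    qg = q.view(B, H, T // BT, BT, DK) * torch.exp2(gc) * scale
    kg = k.view(B, H, T // BT, BT, DK) * torch.exp2(last - gc)
    return qg.view(B, H, T, DK).to(q.dtype), kg.view(B, H, T, DK).to(k.dtype)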
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def bwd_decay_global_cumsum(
 | 
			
		||||
    dq_inner,
 | 
			
		||||
    dq_inter,
 | 
			
		||||
    dk_inner,
 | 
			
		||||
    dk_inter,
 | 
			
		||||
    q, k, g, dg,
 | 
			
		||||
    s_qk_h,
 | 
			
		||||
    s_qk_t,
 | 
			
		||||
    s_qk_d,
 | 
			
		||||
    B,
 | 
			
		||||
    H,
 | 
			
		||||
    T,
 | 
			
		||||
    scale,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BK: tl.constexpr,
 | 
			
		||||
    DK: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
 | 
			
		||||
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
 | 
			
		||||
    p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
 | 
			
		||||
    p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
 | 
			
		||||
    p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
 | 
			
		||||
    p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
 | 
			
		||||
    p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
 | 
			
		||||
    p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
 | 
			
		||||
    cum_grad_dg = tl.zeros([BK], dtype=tl.float32)
 | 
			
		||||
    mask = (i_k * BK + tl.arange(0, BK)) < DK
 | 
			
		||||
    last_g = tl.zeros([BK], dtype=tl.float32)
 | 
			
		||||
    for j in range(BT-1, -1, -1):
 | 
			
		||||
        _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        if j == (BT-1):
 | 
			
		||||
            last_g = _g
 | 
			
		||||
        _dq1 = tl.load(p_dq_inner, mask=mask, other=0)
 | 
			
		||||
        _dq2 = tl.load(p_dq_inter, mask=mask, other=0)
 | 
			
		||||
        _dq2 *= tl.math.exp2(_g)
 | 
			
		||||
        _dq = _dq1 + _dq2
 | 
			
		||||
        tl.store(p_dq_inter, _dq, mask=mask)
 | 
			
		||||
        _dk1 = tl.load(p_dk_inner, mask=mask, other=0)
 | 
			
		||||
        _dk2 = tl.load(p_dk_inter, mask=mask, other=0)
 | 
			
		||||
        _dk2 *= tl.math.exp2(last_g - _g)
 | 
			
		||||
        _dk = _dk1 + _dk2
 | 
			
		||||
        tl.store(p_dk_inter, _dk, mask=mask)
 | 
			
		||||
        _q = tl.load(p_q, mask=mask, other=0)
 | 
			
		||||
        _k = tl.load(p_k, mask=mask, other=0)
 | 
			
		||||
        _dg = _dq * _q - _dk * _k
 | 
			
		||||
        cum_grad_dg += _dg
 | 
			
		||||
        tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)
 | 
			
		||||
        p_g -= DK
 | 
			
		||||
        p_k -= DK
 | 
			
		||||
        p_q -= DK
 | 
			
		||||
        p_dq_inner -= DK
 | 
			
		||||
        p_dk_inner -= DK
 | 
			
		||||
        p_dq_inter -= DK
 | 
			
		||||
        p_dk_inter -= DK
 | 
			
		||||
        p_dg -= DK
 | 
			
		||||
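
# A hedged PyTorch sketch of the kernel above (illustrative only): the inter-chunk gradients are
# rescaled by the cumulative decay and merged with the intra-chunk ones, and dg is the within-chunk
# reversed (suffix) cumsum of dq*q - dk*k. T is assumed to be a multiple of BT.
import torch

def bwd_decay_global_cumsum_ref(dq_inner, dq_inter, dk_inner, dk_inter, q, k, g_cum, BT=16):
    B, H, T, DK = q.shape
    c = lambda x: x.float().view(B, H, T // BT, BT, DK)
    gc = c(g_cum)
    last = gc[..., -1:, :]
    dq = c(dq_inner) + c(dq_inter) * torch.exp2(gc)
    dk = c(dk_inner) + c(dk_inter) * torch.exp2(last - gc)
    dg = (dq * c(q) - dk * c(k)).flip(-2).cumsum(-2).flip(-2)  # reversed cumsum within each chunk
    return dq.view(B, H, T, DK), dk.view(B, H, T, DK), dg.view(B, H, T, DK)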
 | 
			
		||||
							
								
								
									
116  finetune/lora/v6/fla/ops/gla/naive.py  vendored  Normal file
@ -0,0 +1,116 @@
 | 
			
		||||
# -*- coding: utf-8 -*-

import torch
import torch.nn.functional as F

from fla.ops.gla.recurrent_fuse import fused_recurrent_gla


def ceildiv(a, b):
    return -(a // -b)


def naive_recurrent_gla(
    q,
    k,
    v,
    gk,
    initial_state=None,
    output_final_state=False,
    causal=True
):
    orig_dtype = q.dtype
    q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    o = torch.zeros_like(v)
    scale = d_head_k ** -0.5

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        q_i = q[:, :, i, :] * scale
        k_i = k[:, :, i]
        v_i = v[:, :, i, :]
        gk_i = gk[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        h = h * gk_i[..., None] + kv_i
        o_i = (q_i[..., None] * h).sum(-2)
        o[:, :, i] = o_i

    if causal:
        return o.to(orig_dtype), h
    else:
        o_reverse = torch.zeros_like(v)
        h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
        for i in range(seq_len-1, -1, -1):
            q_i = q[:, :, i, :] * scale
            k_i = k[:, :, i]
            v_i = v[:, :, i, :]
            gk_i = gk[:, :, i].exp()
            kv_i = k_i[..., None] * v_i[..., None, :]
            h = h * gk_i[..., None] + kv_i
            o_i = (q_i[..., None] * h).sum(-2)
            o_reverse[:, :, i] = o_i

        return o, o_reverse


if __name__ == "__main__":
    B = 4
    H = 4
    L = 512
    D = 128
    dtype = torch.float32
    q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
    k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
    v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
    g = F.logsigmoid(torch.rand(B, H, L, D)).cuda(
    ).clamp_min(-1).to(torch.float32).requires_grad_(True)

    do = torch.rand_like(v).cuda()
    do2 = torch.rand_like(v).cuda()
    initial_state = torch.rand(B, H, D, D).cuda()

    ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False)

    ref.backward(do, retain_graph=True)
    ref_rev.backward(do2, retain_graph=True)

    ref_dq, q.grad = q.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dv, v.grad = v.grad.clone(), None
    ref_dg, g.grad = g.grad.clone(), None

    tri, tri_rev = fused_recurrent_gla(
        q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, causal=False)
    tri.backward(do, retain_graph=True)
    tri_rev.backward(do2, retain_graph=True)
    tri_dq, q.grad = q.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dv, v.grad = v.grad.clone(), None
    tri_dg, g.grad = g.grad.clone(), None

    assert ref.allclose(tri, 0, 1e-5), breakpoint()
    assert ref_rev.allclose(tri_rev, 0, 1e-5), breakpoint()
    assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
    assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
    assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
    assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()

    # tri = fused_chunk_gla(q, k, v, g)
    # tri.backward(do, retain_graph=True)
    # tri_dq, q.grad = q.grad.clone(), None
    # tri_dk, k.grad = k.grad.clone(), None
    # tri_dv, v.grad = v.grad.clone(), None
    # tri_dg, g.grad = g.grad.clone(), None

    # assert ref.allclose(tri, 0, 1e-5), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
    # assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
    # breakpoint()
    print("Pass")
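The self-test above only exercises the non-causal path. A minimal sketch of also checking the causal path against the fused kernel, reusing the otherwise unused initial_state tensor (an assumed extension, not part of the vendored file), might look like:

    # hypothetical extra check: causal path with an initial state
    ref_c, ref_h = naive_recurrent_gla(q, k, v, g, initial_state=initial_state, causal=True)
    tri_c, tri_h = fused_recurrent_gla(q, k, v, g, initial_state=initial_state,
                                       scale=D**-0.5, output_final_state=True, causal=True)
    assert ref_c.allclose(tri_c, 0, 1e-5), breakpoint()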
							
								
								
									
404  finetune/lora/v6/fla/ops/gla/recurrent_fuse.py  vendored  Normal file
@ -0,0 +1,404 @@
# -*- coding: utf-8 -*-

# Copyright (c) 2023, Songlin Yang

from typing import Tuple

import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd

from fla.utils import contiguous

# on-the-fly computation without materializing hidden states into HBM

@triton.jit
 | 
			
		||||
def fused_recurrent_gla_fwd_kernel(
 | 
			
		||||
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
 | 
			
		||||
    q,  # query [B, H, L, D_head_K]
 | 
			
		||||
    k,  # key [B, H, L, D_head_K]
 | 
			
		||||
    v,  # value [B, H, L, D_head_V]
 | 
			
		||||
    gk,  # log gate [B, H, L, D_head_K]
 | 
			
		||||
    gv,  # log gate [B, H, L, D_head_V]
 | 
			
		||||
    o,  # output [B, H, L, D_head_V]
 | 
			
		||||
    # initial hidden state initialization [B, H, D_head_K, D_head_V]
 | 
			
		||||
    initial_state,
 | 
			
		||||
    final_state,  # final hidden state [B, H, D_head_K, D_head_V]
 | 
			
		||||
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    s_vo_h,  # stride size: L * D_head_V
 | 
			
		||||
    s_vo_t,  # stride size: D_head_V
 | 
			
		||||
    s_vo_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    B,  # batch size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
    DV: tl.constexpr,  # D_head_V
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
 | 
			
		||||
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
 | 
			
		||||
    REVERSE: tl.constexpr,  # whether to do autoregressive modeling in the reverse direction
 | 
			
		||||
    USE_GK: tl.constexpr,  # whether to use gk
 | 
			
		||||
    USE_GV: tl.constexpr,  # whether to use gv
 | 
			
		||||
):
 | 
			
		||||
    # indices
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    p_q = q + i_bh * s_qk_h + i_k * BK + \
 | 
			
		||||
        tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
 | 
			
		||||
    p_k = k + i_bh * s_qk_h + i_k * BK + \
 | 
			
		||||
        tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
 | 
			
		||||
    p_v = v + i_bh * s_vo_h + i_v * BV + \
 | 
			
		||||
        tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
 | 
			
		||||
    p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + \
 | 
			
		||||
        tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
 | 
			
		||||
 | 
			
		||||
    if USE_GK:
 | 
			
		||||
        p_gk = gk + i_bh * s_qk_h + i_k * BK + \
 | 
			
		||||
            tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
 | 
			
		||||
    if USE_GV:
 | 
			
		||||
        p_gv = gv + i_bh * s_vo_h + i_v * BV + \
 | 
			
		||||
            tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
 | 
			
		||||
 | 
			
		||||
    mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
 | 
			
		||||
    mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
 | 
			
		||||
 | 
			
		||||
    h = tl.zeros([BV, BK], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    mask_kv = mask_bk[None, :] & mask_bv[:, None]
 | 
			
		||||
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_init_s = initial_state + i_bh * DK * DV + \
 | 
			
		||||
            (i_k * BK + tl.arange(0, BK)[None, :]) * \
 | 
			
		||||
            DV + (i_v * BV + tl.arange(0, BV)[:, None])
 | 
			
		||||
        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
    for _ in range(0, T):
 | 
			
		||||
        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
            h = h * _gk[None, :]
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
            h = h * _gv[:, None]
 | 
			
		||||
        h += _k[None, :] * _v[:, None]
 | 
			
		||||
        _o = h * _q[None, :]
 | 
			
		||||
        _o = tl.sum(_o, axis=1)
 | 
			
		||||
        tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
 | 
			
		||||
        p_q += -DK if REVERSE else DK
 | 
			
		||||
        p_k += -DK if REVERSE else DK
 | 
			
		||||
        p_o += -DV if REVERSE else DV
 | 
			
		||||
        p_v += -DV if REVERSE else DV
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            p_gk += -DK if REVERSE else DK
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            p_gv += -DV if REVERSE else DV
 | 
			
		||||
 | 
			
		||||
    if STORE_FINAL_STATE:
 | 
			
		||||
        p_final_s = final_state + i_bh * DK * DV + \
 | 
			
		||||
            (i_k * BK + tl.arange(0, BK)[None, :]) * \
 | 
			
		||||
            DV + (i_v * BV + tl.arange(0, BV)[:, None])
 | 
			
		||||
        tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_recurrent_gla_bwd_kernel(
 | 
			
		||||
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
 | 
			
		||||
    # NV: number of splits in the V dimension. NK: number of splits in the K dimension
 | 
			
		||||
    q,  # query [B, H, L, D_head_K]
 | 
			
		||||
    k,  # key [B, H, L, D_head_K]
 | 
			
		||||
    v,  # value [B, H, L, D_head_V]
 | 
			
		||||
    gk,  # log gate [B, H, L, D_head_K] \alpha
 | 
			
		||||
    gv,  # log gate [B, H, L, D_head_V] \beta
 | 
			
		||||
 | 
			
		||||
    do,  # gradient of output [B, H, L, D_head_V]
 | 
			
		||||
    dq,  # gradient of query [NV, B, H, L, D_head_K]
 | 
			
		||||
    dk,  # gradient of key [NV, B, H, L, D_head_K]
 | 
			
		||||
    dv,  # gradient of value [NK, B, H, L, D_head_V]
 | 
			
		||||
 | 
			
		||||
    # initial hidden state initialization [B, H, D_head_K, D_head_V]
 | 
			
		||||
    initial_state,
 | 
			
		||||
 | 
			
		||||
    s_qk_h,  # stride size: L * D_head_K
 | 
			
		||||
    s_qk_t,  # stride size: D_head_K
 | 
			
		||||
    s_qk_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    s_vo_h,  # stride size: L * D_head_V
 | 
			
		||||
    s_vo_t,  # stride size: D_head_V
 | 
			
		||||
    s_vo_d,  # stride size: 1
 | 
			
		||||
 | 
			
		||||
    B,  # batch_size
 | 
			
		||||
    H,  # n_heads
 | 
			
		||||
    T,  # seq_len
 | 
			
		||||
    scale,  # D_head_K ** -0.5
 | 
			
		||||
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
 | 
			
		||||
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
 | 
			
		||||
    DK: tl.constexpr,  # D_head_K
 | 
			
		||||
    DV: tl.constexpr,  # D_head_V
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
 | 
			
		||||
    REVERSE: tl.constexpr,  # whether to do autoregressive modeling in the reverse direction
 | 
			
		||||
    USE_GK: tl.constexpr,  # whether to use gk
 | 
			
		||||
    USE_GV: tl.constexpr,  # whether to use gv
 | 
			
		||||
):
 | 
			
		||||
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
 | 
			
		||||
    p_q = q + i_bh * s_qk_h + i_k * BK + \
 | 
			
		||||
        tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
 | 
			
		||||
    p_k = k + i_bh * s_qk_h + i_k * BK + \
 | 
			
		||||
        tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
 | 
			
		||||
    p_v = v + i_bh * s_vo_h + i_v * BV + \
 | 
			
		||||
        tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
 | 
			
		||||
    p_do = do + i_bh * s_vo_h + i_v * BV + \
 | 
			
		||||
        tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
 | 
			
		||||
    p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + \
 | 
			
		||||
        tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
 | 
			
		||||
    if USE_GK:
 | 
			
		||||
        p_gk = gk + i_bh * s_qk_h + i_k * BK + \
 | 
			
		||||
            tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
 | 
			
		||||
    if USE_GV:
 | 
			
		||||
        p_gv = gv + i_bh * s_vo_h + i_v * BV + \
 | 
			
		||||
            tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
 | 
			
		||||
    mask_bk = i_k * BK + tl.arange(0, BK) < DK
 | 
			
		||||
    mask_bv = i_v * BV + tl.arange(0, BV) < DV
 | 
			
		||||
    mask_kv = mask_bk[:, None] & mask_bv[None, :]
 | 
			
		||||
    h = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_init_s = initial_state + i_bh * DK * DV + \
 | 
			
		||||
            (i_k * BK + tl.arange(0, BK)[:, None]) * \
 | 
			
		||||
            DV + (i_v * BV + tl.arange(0, BV)[None, :])
 | 
			
		||||
        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
    for i in range(0, T):
 | 
			
		||||
        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
            h = h * _gk[:, None]
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
            h = h * _gv[None, :]
 | 
			
		||||
        h += _k[:, None] * _v[None, :]
 | 
			
		||||
        _d_q = h * _do[None, :]
 | 
			
		||||
        d_q = tl.sum(_d_q, axis=1) * scale
 | 
			
		||||
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
 | 
			
		||||
 | 
			
		||||
        p_k += -DK if REVERSE else DK
 | 
			
		||||
        p_v += -DV if REVERSE else DV
 | 
			
		||||
        p_q += -DK if REVERSE else DK
 | 
			
		||||
        p_do += -DV if REVERSE else DV
 | 
			
		||||
        p_dq += -DK if REVERSE else DK
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            p_gk += -DK if REVERSE else DK
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            p_gv += -DV if REVERSE else DV
 | 
			
		||||
 | 
			
		||||
    # sync threads
 | 
			
		||||
    tl.debug_barrier()
 | 
			
		||||
 | 
			
		||||
    p_q = q + i_bh * s_qk_h + i_k * BK + \
 | 
			
		||||
        tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
 | 
			
		||||
    p_k = k + i_bh * s_qk_h + i_k * BK + \
 | 
			
		||||
        tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
 | 
			
		||||
    p_do = do + i_bh * s_vo_h + i_v * BV + \
 | 
			
		||||
        tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
 | 
			
		||||
    p_v = v + i_bh * s_vo_h + i_v * BV + \
 | 
			
		||||
        tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
 | 
			
		||||
    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
 | 
			
		||||
        BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
 | 
			
		||||
    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
 | 
			
		||||
        BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
 | 
			
		||||
    if USE_GK:
 | 
			
		||||
        p_gk = gk + i_bh * s_qk_h + i_k * BK + \
 | 
			
		||||
            tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
 | 
			
		||||
    if USE_GV:
 | 
			
		||||
        p_gv = gv + i_bh * s_vo_h + i_v * BV + \
 | 
			
		||||
            tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
 | 
			
		||||
 | 
			
		||||
    d_h = tl.zeros([BK, BV], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
    for _ in range(T):
 | 
			
		||||
        _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
 | 
			
		||||
        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
        d_h += _q[:, None] * _do[None, :]
 | 
			
		||||
        d_k = tl.sum(d_h * _v[None, :], axis=1)
 | 
			
		||||
        d_v = tl.sum(d_h * _k[:, None], axis=0)
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
 | 
			
		||||
            d_h *= _gk[:, None]
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
 | 
			
		||||
            d_h *= _gv[None, :]
 | 
			
		||||
        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
 | 
			
		||||
        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
 | 
			
		||||
 | 
			
		||||
        p_do += DV if REVERSE else -DV
 | 
			
		||||
        p_q += DK if REVERSE else -DK
 | 
			
		||||
        p_k += DK if REVERSE else -DK
 | 
			
		||||
        p_v += DV if REVERSE else -DV
 | 
			
		||||
        p_dk += DK if REVERSE else -DK
 | 
			
		||||
        p_dv += DV if REVERSE else -DV
 | 
			
		||||
        if USE_GK:
 | 
			
		||||
            p_gk += DK if REVERSE else -DK
 | 
			
		||||
        if USE_GV:
 | 
			
		||||
            p_gv += DV if REVERSE else -DV
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedRecurrentGLAFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_fwd
 | 
			
		||||
    def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):
 | 
			
		||||
        batch_size, n_heads, seq_len, d_head_qk = q.shape
 | 
			
		||||
        d_head_v = v.shape[-1]
 | 
			
		||||
        # default scale
 | 
			
		||||
        if scale is None:
 | 
			
		||||
            scale = d_head_qk ** -0.5
 | 
			
		||||
        if gk is not None:
 | 
			
		||||
            gk = gk.float().exp()
 | 
			
		||||
        if gv is not None:
 | 
			
		||||
            gv = gv.float().exp()
 | 
			
		||||
 | 
			
		||||
        BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
 | 
			
		||||
        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
 | 
			
		||||
        num_stages = 1
 | 
			
		||||
        num_warps = 1
 | 
			
		||||
 | 
			
		||||
        o = q.new_empty(NK, batch_size, n_heads, seq_len,
 | 
			
		||||
                        d_head_v, dtype=torch.float32)
 | 
			
		||||
 | 
			
		||||
        if output_final_state:
 | 
			
		||||
            final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
 | 
			
		||||
        else:
 | 
			
		||||
            final_state = None
 | 
			
		||||
 | 
			
		||||
        grid = (NV, NK, batch_size * n_heads)
 | 
			
		||||
        fused_recurrent_gla_fwd_kernel[grid](
 | 
			
		||||
            q, k, v, gk, gv, o, initial_state, final_state,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
 | 
			
		||||
            USE_INITIAL_STATE=initial_state is not None,
 | 
			
		||||
            STORE_FINAL_STATE=final_state is not None,
 | 
			
		||||
            USE_GK=gk is not None,
 | 
			
		||||
            USE_GV=gv is not None,
 | 
			
		||||
            REVERSE=reverse,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        o = o.sum(0)
 | 
			
		||||
        ctx.save_for_backward(q, k, v, gk, gv, initial_state, o)
 | 
			
		||||
        ctx.scale = scale
 | 
			
		||||
        ctx.reverse = reverse
 | 
			
		||||
        # we do not need the gradient of the final state from the next chunk
 | 
			
		||||
        # similar to truncated BPTT
 | 
			
		||||
        if final_state is not None:
 | 
			
		||||
            final_state = final_state.detach()
 | 
			
		||||
        return o.to(q.dtype), final_state
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    @custom_bwd
 | 
			
		||||
    def backward(ctx, do, d_final_state=None):
 | 
			
		||||
        q, k, v, gk, gv, initial_state, o = ctx.saved_tensors
 | 
			
		||||
        batch_size, n_heads, seq_len, d_head_qk = q.shape
 | 
			
		||||
        d_head_v = v.shape[-1]
 | 
			
		||||
        scale = ctx.scale
 | 
			
		||||
 | 
			
		||||
        BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
 | 
			
		||||
        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
 | 
			
		||||
        num_stages = 1
 | 
			
		||||
        num_warps = 1
 | 
			
		||||
 | 
			
		||||
        dq = q.new_empty(NV, batch_size, n_heads,  seq_len,
 | 
			
		||||
                         d_head_qk, dtype=torch.float32)
 | 
			
		||||
        dk = q.new_empty(NV, batch_size, n_heads,  seq_len,
 | 
			
		||||
                         d_head_qk, dtype=torch.float32)
 | 
			
		||||
        dv = q.new_empty(NK, batch_size, n_heads, seq_len,
 | 
			
		||||
                         d_head_v, dtype=torch.float32)
 | 
			
		||||
        grid = (NV, NK, batch_size * n_heads)
 | 
			
		||||
 | 
			
		||||
        fused_recurrent_gla_bwd_kernel[grid](
 | 
			
		||||
            q, k, v, gk, gv, do, dq, dk, dv, initial_state,
 | 
			
		||||
            q.stride(1), q.stride(2), q.stride(3),
 | 
			
		||||
            v.stride(1), v.stride(2), v.stride(3),
 | 
			
		||||
            batch_size, n_heads, seq_len, scale,
 | 
			
		||||
            DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
 | 
			
		||||
            num_warps=num_warps,
 | 
			
		||||
            num_stages=num_stages,
 | 
			
		||||
            USE_INITIAL_STATE=initial_state is not None,
 | 
			
		||||
            REVERSE=ctx.reverse,
 | 
			
		||||
            USE_GK=gk is not None,
 | 
			
		||||
            USE_GV=gv is not None
 | 
			
		||||
        )
 | 
			
		||||
        dq = dq.sum(0)
 | 
			
		||||
        dk = dk.sum(0)
 | 
			
		||||
        dv = dv.sum(0)
 | 
			
		||||
        if gk is not None:
 | 
			
		||||
            _dgk = dq * q.float() - dk * k.float()
 | 
			
		||||
            if ctx.reverse:
 | 
			
		||||
                dgk = _dgk.cumsum(-2)
 | 
			
		||||
            else:
 | 
			
		||||
                _dgk_cumsum = _dgk.cumsum(-2)
 | 
			
		||||
                dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum
 | 
			
		||||
        else:
 | 
			
		||||
            dgk = None
 | 
			
		||||
 | 
			
		||||
        if gv is not None:
 | 
			
		||||
            _dgv = do.float() * o.float() - dv * v.float()
 | 
			
		||||
            if ctx.reverse:
 | 
			
		||||
                dgv = _dgv.cumsum(-2)
 | 
			
		||||
            else:
 | 
			
		||||
                _dgv_cumsum = _dgv.cumsum(-2)
 | 
			
		||||
                dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum
 | 
			
		||||
        else:
 | 
			
		||||
            dgv = None
 | 
			
		||||
 | 
			
		||||
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None
 | 
			
		||||


# if scale is None, use d_head_qk ** -0.5 by default. Otherwise specify the scale yourself. e.g. scale = 1.0
def fused_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: torch.Tensor = None,
    gv: torch.Tensor = None,
    scale: int = -1,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    causal: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    if scale == -1:
        scale = q.shape[-1] ** -0.5
    if initial_state is not None:
        initial_state = initial_state.detach()
    if causal:
        o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state)
        return o, final_state
    else:
        # do not support initial_state yet. looks very strange for bidirectional modeling
        assert initial_state is None
        assert output_final_state is False
        o, final_state = FusedRecurrentGLAFunction.apply(
            q, k, v, gk, gv, scale, initial_state, output_final_state, False)
        o_reversed, final_state = FusedRecurrentGLAFunction.apply(
            q, k, v, gk, gv, scale, initial_state, output_final_state, True)
        return [o, o_reversed]
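For orientation (a sketch, not part of the vendored file): how the launch configuration inside FusedRecurrentGLAFunction works out for the 128-dimensional heads used in the naive.py test above. The head dimensions are split into blocks of at most 32, and the NK partial outputs are reduced afterwards with o.sum(0):

    import triton

    d_head_qk, d_head_v = 128, 128
    BK, BV = min(d_head_qk, 32), min(d_head_v, 32)                   # 32, 32
    NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)   # 4, 4
    grid = (NV, NK, 4 * 4)   # (NV, NK, batch_size * n_heads) for B = H = 4
    # o is allocated as [NK, B, H, T, d_head_v] and reduced over the K splits via o.sum(0)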
							
								
								
									
9  finetune/lora/v6/fla/ops/hgrn/__init__.py  vendored  Normal file
@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_hgrn
from .recurrent_fuse import fused_recurrent_hgrn

__all__ = [
    'chunk_hgrn',
    'fused_recurrent_hgrn'
]
							
								
								
									
373  finetune/lora/v6/fla/ops/hgrn/chunk.py  vendored  Normal file
@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-

# Copyright (c) 2024, Yu Zhang, Songlin Yang

# this function implements the chunkwise form of HGRN, inspired by
# [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html)
# also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan

# from tests on H800, with B, H, D = 16, 4, 128, the chunk kernel can be significantly faster than the recurrent one:
#
# Performance:
#    seq_len     chunk  recurrent  chunk_bwd  recurrent_bwd
# 0    128.0  0.039360   0.061056   0.312160       0.205008
# 1    256.0  0.045824   0.123712   0.308784       0.297696
# 2    512.0  0.058688   0.241952   0.310720       0.626528
# 3   1024.0  0.088288   0.476992   0.313184       1.333152
# 4   2048.0  0.169472   0.943264   0.452464       2.724864
# 5   4096.0  0.329920   1.886144   0.881600       5.551520
# 6   8192.0  0.647872   3.755040   1.740496      11.117184
# 7  16384.0  1.272064   7.520576   3.446608      22.362528

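# note (added, not part of the vendored file): a rough sketch of the chunkwise recombination
# implemented by the two forward kernels below. Within each chunk, chunk_hgrn_fwd_kernel_h runs
# a local scan with a zero carry and accumulates the running log-gate sum gc; chunk_hgrn_fwd_kernel_o
# then folds in the carry from the previous chunk:
#
#     o[t] = h_local[t] + exp(gc[t]) * h_carry    # h_carry: last corrected output of the previous chunk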
from typing import Tuple

import torch
import triton
import triton.language as tl

from fla.utils import contiguous


@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=8),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=8),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=8),
 | 
			
		||||
    ],
 | 
			
		||||
    key=['D']
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_hgrn_fwd_kernel_h(
 | 
			
		||||
    x,
 | 
			
		||||
    g,
 | 
			
		||||
    gc,
 | 
			
		||||
    o,
 | 
			
		||||
    h0,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    D: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BD: tl.constexpr,
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    o_d = i_d * BD + tl.arange(0, BD)
 | 
			
		||||
    mask = o_d < D
 | 
			
		||||
 | 
			
		||||
    p_x = x + i_bh * T * D + i_t * BT * D + o_d
 | 
			
		||||
    p_g = g + i_bh * T * D + i_t * BT * D + o_d
 | 
			
		||||
    p_gc = gc + i_bh * T * D + i_t * BT * D + o_d
 | 
			
		||||
    p_o = o + i_bh * T * D + i_t * BT * D + o_d
 | 
			
		||||
 | 
			
		||||
    b_h = tl.zeros([BD], dtype=tl.float32)
 | 
			
		||||
    b_gc = tl.zeros([BD], dtype=tl.float32)
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        if i_t == 0:
 | 
			
		||||
            b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
    for i in range(0, BT):
 | 
			
		||||
        mask_t = mask & ((i_t * BT + i) < T)
 | 
			
		||||
        b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)
 | 
			
		||||
        b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)
 | 
			
		||||
        b_h = tl.exp(b_g) * b_h + b_x
 | 
			
		||||
        b_gc = b_gc + b_g
 | 
			
		||||
        tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)
 | 
			
		||||
        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)
 | 
			
		||||
 | 
			
		||||
        p_x += D
 | 
			
		||||
        p_g += D
 | 
			
		||||
        p_gc += D
 | 
			
		||||
        p_o += D
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_hgrn_fwd_kernel_o(
 | 
			
		||||
    gc,
 | 
			
		||||
    o,
 | 
			
		||||
    s_h,
 | 
			
		||||
    s_t,
 | 
			
		||||
    s_d,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    D: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BD: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
 | 
			
		||||
    o_d = i_d * BD + tl.arange(0, BD)
 | 
			
		||||
    mask = o_d < D
 | 
			
		||||
 | 
			
		||||
    for i_t in range(1, tl.cdiv(T, BT)):
 | 
			
		||||
        p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
 | 
			
		||||
        p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
 | 
			
		||||
 | 
			
		||||
        # [BD,]
 | 
			
		||||
        b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        # [BT, BD]
 | 
			
		||||
        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
        b_o = b_o + tl.exp(b_gc) * b_h0[None, :]
 | 
			
		||||
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=8),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=8),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=8),
 | 
			
		||||
    ],
 | 
			
		||||
    key=['D']
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_hgrn_bwd_kernel_h(
 | 
			
		||||
    g,
 | 
			
		||||
    gc,
 | 
			
		||||
    dx,
 | 
			
		||||
    do,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    D: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BD: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
 | 
			
		||||
    o_d = i_d * BD + tl.arange(0, BD)
 | 
			
		||||
    mask = o_d < D
 | 
			
		||||
    BC = min(BT, T - i_t * BT)
 | 
			
		||||
    NT = tl.num_programs(1)
 | 
			
		||||
 | 
			
		||||
    p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d
 | 
			
		||||
    p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d
 | 
			
		||||
    p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d
 | 
			
		||||
    p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d
 | 
			
		||||
 | 
			
		||||
    if i_t == NT - 1:
 | 
			
		||||
        b_gc = tl.zeros([BD], dtype=tl.float32)
 | 
			
		||||
    else:
 | 
			
		||||
        b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
    b_dh = tl.zeros([BD], dtype=tl.float32)
 | 
			
		||||
    for _ in range(BC - 1, -1, -1):
 | 
			
		||||
        tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)
 | 
			
		||||
 | 
			
		||||
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
 | 
			
		||||
        b_gc = b_gc + b_g
 | 
			
		||||
        b_dh = b_dh + b_do
 | 
			
		||||
        b_dx = b_dh
 | 
			
		||||
        b_dh = b_dh * tl.exp(b_g)
 | 
			
		||||
 | 
			
		||||
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
 | 
			
		||||
 | 
			
		||||
        p_g -= D
 | 
			
		||||
        p_gc -= D
 | 
			
		||||
        p_dx -= D
 | 
			
		||||
        p_do -= D
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.jit
 | 
			
		||||
def chunk_hgrn_bwd_kernel_o(
 | 
			
		||||
    g,
 | 
			
		||||
    gc,
 | 
			
		||||
    o,
 | 
			
		||||
    dx,
 | 
			
		||||
    dg,
 | 
			
		||||
    s_h,
 | 
			
		||||
    s_t,
 | 
			
		||||
    s_d,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    D: tl.constexpr,
 | 
			
		||||
    BT: tl.constexpr,
 | 
			
		||||
    BD: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
 | 
			
		||||
    o_d = i_d * BD + tl.arange(0, BD)
 | 
			
		||||
    mask = o_d < D
 | 
			
		||||
 | 
			
		||||
    for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
 | 
			
		||||
        p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
 | 
			
		||||
        p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
 | 
			
		||||
        p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))
 | 
			
		||||
        p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
 | 
			
		||||
        p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
 | 
			
		||||
 | 
			
		||||
        # [BD,]
 | 
			
		||||
        mask_t = mask & ((i_t + 1) * BT < T)
 | 
			
		||||
        b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)
 | 
			
		||||
        # [BT, BD]
 | 
			
		||||
        b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
        b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
        b_dg = tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)
 | 
			
		||||
        b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]
 | 
			
		||||
        b_dg = b_o * b_dx * tl.exp(b_g)
 | 
			
		||||
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChunkHGRNFunction(torch.autograd.Function):
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def forward(ctx, x, g, initial_state=None, output_final_state=False):
 | 
			
		||||
        B, H, T, D = x.shape
 | 
			
		||||
        BT, BD = 128, min(64, triton.next_power_of_2(D))
 | 
			
		||||
        num_warps = 8 if BD == 64 else 4
 | 
			
		||||
 | 
			
		||||
        gc = torch.empty_like(g, dtype=torch.float)
 | 
			
		||||
        o = torch.empty_like(x, dtype=torch.float)
 | 
			
		||||
        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
 | 
			
		||||
        chunk_hgrn_fwd_kernel_h[grid](
 | 
			
		||||
            x, g, gc, o, initial_state,
 | 
			
		||||
            T, D,
 | 
			
		||||
            BT=BT,
 | 
			
		||||
            USE_INITIAL_STATE=initial_state is not None
 | 
			
		||||
        )
 | 
			
		||||
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
 | 
			
		||||
        chunk_hgrn_fwd_kernel_o[grid](
 | 
			
		||||
            gc, o,
 | 
			
		||||
            o.stride(1), o.stride(2), o.stride(3),
 | 
			
		||||
            T, D,
 | 
			
		||||
            BT=BT, BD=BD,
 | 
			
		||||
            num_warps=num_warps
 | 
			
		||||
        )
 | 
			
		||||
        final_state = None
 | 
			
		||||
        if output_final_state:
 | 
			
		||||
            final_state = o[:, :, -1].clone()
 | 
			
		||||
        o = o.to(x.dtype)
 | 
			
		||||
        ctx.save_for_backward(g, o, initial_state)
 | 
			
		||||
        return o, final_state
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @contiguous
 | 
			
		||||
    def backward(ctx, do, dht=None):
 | 
			
		||||
        g, o, initial_state = ctx.saved_tensors
 | 
			
		||||
        B, H, T, D = do.shape
 | 
			
		||||
        BT, BD = 128, min(64, triton.next_power_of_2(D))
 | 
			
		||||
        num_warps = 8 if BD == 64 else 4
 | 
			
		||||
 | 
			
		||||
        gc = torch.empty_like(g, dtype=torch.float)
 | 
			
		||||
        dx = torch.empty_like(o)
 | 
			
		||||
        dg = torch.empty_like(g)
 | 
			
		||||
        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
 | 
			
		||||
        chunk_hgrn_bwd_kernel_h[grid](
 | 
			
		||||
            g, gc, dx, do,
 | 
			
		||||
            T, D,
 | 
			
		||||
            BT=BT
 | 
			
		||||
        )
 | 
			
		||||
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
 | 
			
		||||
        chunk_hgrn_bwd_kernel_o[grid](
 | 
			
		||||
            g, gc, o, dx, dg,
 | 
			
		||||
            o.stride(1), o.stride(2), o.stride(3),
 | 
			
		||||
            T, D,
 | 
			
		||||
            BT=BT, BD=BD,
 | 
			
		||||
            num_warps=num_warps
 | 
			
		||||
        )
 | 
			
		||||
        if initial_state is not None:
 | 
			
		||||
            dg[:, :, 0] = initial_state * dx[:, :, 0] * g[:, :, 0].exp()
 | 
			
		||||
 | 
			
		||||
        return dx, dg, None, None
 | 
			
		||||


def chunk_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    if initial_state is not None:
        initial_state = initial_state.detach()
    o, final_state = ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)
    return o, final_state
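A minimal usage sketch for the wrapper above (assuming x and log-space gates g of shape [B, H, T, D] and an optional initial state h0 of shape [B, H, D], as in the self-test below):

    o, ht = chunk_hgrn(x, g, initial_state=h0, output_final_state=True)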
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
    from fla.ops.hgrn.naive import naive_recurrent_hgrn
 | 
			
		||||
    from fla.ops.hgrn.recurrent_fuse import fused_recurrent_hgrn
 | 
			
		||||
    B, H, T, D = 8, 4, 512, 128
 | 
			
		||||
    dtype = torch.bfloat16
 | 
			
		||||
    torch.manual_seed(42)
 | 
			
		||||
    # [batch_size, n_heads, seq_len, d_head]
 | 
			
		||||
    x = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
 | 
			
		||||
    g = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
 | 
			
		||||
    x, g = (1 - g.sigmoid()) * x, F.logsigmoid(g)
 | 
			
		||||
    print(f'x:\t{float(x.min()):>10.6f}\t{float(x.max()):>10.6f}')
 | 
			
		||||
    print(f'g:\t{float(g.min()):>10.6f}\t{float(g.max()):>10.6f}')
 | 
			
		||||
    x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g))
 | 
			
		||||
    print(f"DTYPE:\t{x.dtype}")
 | 
			
		||||
    do = torch.randn_like(x)
 | 
			
		||||
    h0 = torch.randn_like(x[:, :, 0])
 | 
			
		||||
    ref, ref_ht = naive_recurrent_hgrn(x, g, h0, output_final_state=True)
 | 
			
		||||
    ref.backward(do)
 | 
			
		||||
    ref_dx, x.grad = x.grad.clone(), None
 | 
			
		||||
    ref_dg, g.grad = g.grad.clone(), None
 | 
			
		||||
 | 
			
		||||
    tri, tri_ht = fused_recurrent_hgrn(x, g, h0, output_final_state=True)
 | 
			
		||||
    tri.backward(do)
 | 
			
		||||
    tri_dx, x.grad = x.grad.clone(), None
 | 
			
		||||
    tri_dg, g.grad = g.grad.clone(), None
 | 
			
		||||
    print("  \t    DIFF\t    MAX")
 | 
			
		||||
    print(' o\t', f"{float((ref - tri).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}")
 | 
			
		||||
    print('ht\t', f"{float((ref_ht[0] - tri_ht[0]).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}")
 | 
			
		||||
    print('dx\t', f"{float((ref_dx - tri_dx).abs().max()):>10.6f}\t{float(ref_dx.max()):>10.6f}")
 | 
			
		||||
    print('dg\t', f"{float((ref_dg - tri_dg).abs().max()):>10.6f}\t{float(ref_dg.max()):>10.6f}")
 | 
			
		||||
    print('Done!')
 | 
			
		||||
 | 
			
		||||
    @triton.testing.perf_report(
 | 
			
		||||
        triton.testing.Benchmark(
 | 
			
		||||
            # argument names to use as an x-axis for the plot
 | 
			
		||||
            x_names=['seq_len'],
 | 
			
		||||
            # different possible values for `x_name`
 | 
			
		||||
            x_vals=[128 * 2 ** i for i in range(0, 8)],
 | 
			
		||||
            # argument name whose value corresponds to a different line in the plot
 | 
			
		||||
            line_arg='provider',
 | 
			
		||||
            # possible values for `line_arg`
 | 
			
		||||
            line_vals=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
 | 
			
		||||
            # label name for the lines
 | 
			
		||||
            line_names=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
 | 
			
		||||
            # line styles
 | 
			
		||||
            styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')],
 | 
			
		||||
            ylabel="Execution Time (ms)",  # label name for the y-axis
 | 
			
		||||
            # name for the plot. Used also as a file name for saving the plot.
 | 
			
		||||
            plot_name="Performance",
 | 
			
		||||
            args={},
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    def benchmark(seq_len, provider):
 | 
			
		||||
        dtype = torch.bfloat16
 | 
			
		||||
        B, H, D = 16, 4, 128
 | 
			
		||||
 | 
			
		||||
        x = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda')
 | 
			
		||||
        g = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda').sigmoid()
 | 
			
		||||
        x = (1 - g) * x
 | 
			
		||||
        x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g))
 | 
			
		||||
        do = torch.randn_like(x, dtype=dtype)
 | 
			
		||||
        quantiles = [0.5, 0.2, 0.8]
 | 
			
		||||
        results = 0, 0, 0
 | 
			
		||||
        if provider == 'chunk':
 | 
			
		||||
            results = triton.testing.do_bench(lambda: chunk_hgrn(x, g), quantiles=quantiles)
 | 
			
		||||
        if provider == 'recurrent':
 | 
			
		||||
            results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g), quantiles=quantiles)
 | 
			
		||||
        if provider == 'chunk_bwd':
 | 
			
		||||
            results = triton.testing.do_bench(lambda: chunk_hgrn(x, g)[0].backward(do), quantiles=quantiles)
 | 
			
		||||
        if provider == 'recurrent_bwd':
 | 
			
		||||
            results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g)[0].backward(do), quantiles=quantiles)
 | 
			
		||||
        return results
 | 
			
		||||
    benchmark.run(print_data=True)
 | 
			
		||||
							
								
								
									
31  finetune/lora/v6/fla/ops/hgrn/naive.py  vendored  Normal file
@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-

from typing import Optional, Tuple

import torch


def naive_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    dtype = x.dtype
    x, g = map(lambda i: i.float(), (x, g))
    B, H, T, D = x.shape

    h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
    o = torch.zeros_like(x)

    final_state = None
    if initial_state is not None:
        h += initial_state.detach()

    for i in range(T):
        h = g[:, :, i].exp() * h + x[:, :, i]
        o[:, :, i] = h

    if output_final_state:
        final_state = h
    return o.to(dtype), final_state
							
								
								
									
185  finetune/lora/v6/fla/ops/hgrn/recurrent_fuse.py  vendored  Normal file
@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-

# Copyright (c) 2023, Songlin Yang

from typing import Tuple

import torch
import triton
import triton.language as tl

from fla.utils import contiguous


@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=8),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=8),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=8),
 | 
			
		||||
    ],
 | 
			
		||||
    key=['D']
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_recurrent_hgrn_fwd_kernel(
 | 
			
		||||
    x,
 | 
			
		||||
    g,
 | 
			
		||||
    o,
 | 
			
		||||
    h0,
 | 
			
		||||
    ht,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    D: tl.constexpr,
 | 
			
		||||
    BD: tl.constexpr,
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr,
 | 
			
		||||
    STORE_FINAL_STATE: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
 | 
			
		||||
    o_d = i_d * BD + tl.arange(0, BD)
 | 
			
		||||
    mask = o_d < D
 | 
			
		||||
 | 
			
		||||
    p_x = x + i_bh * T * D + o_d
 | 
			
		||||
    p_g = g + i_bh * T * D + o_d
 | 
			
		||||
    p_o = o + i_bh * T * D + o_d
 | 
			
		||||
 | 
			
		||||
    b_h = tl.zeros([BD], dtype=tl.float32)
 | 
			
		||||
    if USE_INITIAL_STATE:
 | 
			
		||||
        p_h0 = h0 + i_bh * D + o_d
 | 
			
		||||
        b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
    for _ in range(0, T):
 | 
			
		||||
        b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        b_h = tl.exp(b_g) * b_h + b_x
 | 
			
		||||
        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)
 | 
			
		||||
 | 
			
		||||
        p_x += D
 | 
			
		||||
        p_g += D
 | 
			
		||||
        p_o += D
 | 
			
		||||
 | 
			
		||||
    if STORE_FINAL_STATE:
 | 
			
		||||
        p_ht = ht + i_bh * D + o_d
 | 
			
		||||
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@triton.autotune(
 | 
			
		||||
    configs=[
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 32}, num_warps=8),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 64}, num_warps=8),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=1),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=2),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=4),
 | 
			
		||||
        triton.Config({'BD': 128}, num_warps=8),
 | 
			
		||||
    ],
 | 
			
		||||
    key=['D']
 | 
			
		||||
)
 | 
			
		||||
@triton.jit
 | 
			
		||||
def fused_recurrent_hgrn_bwd_kernel(
 | 
			
		||||
    g,
 | 
			
		||||
    o,
 | 
			
		||||
    dx,
 | 
			
		||||
    dg,
 | 
			
		||||
    do,
 | 
			
		||||
    h0,
 | 
			
		||||
    T: tl.constexpr,
 | 
			
		||||
    D: tl.constexpr,
 | 
			
		||||
    BD: tl.constexpr,
 | 
			
		||||
    USE_INITIAL_STATE: tl.constexpr
 | 
			
		||||
):
 | 
			
		||||
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
 | 
			
		||||
    o_d = i_d * BD + tl.arange(0, BD)
 | 
			
		||||
    mask = o_d < D
 | 
			
		||||
 | 
			
		||||
    p_g = g + (i_bh * T + T - 1) * D + o_d
 | 
			
		||||
    p_o = o + (i_bh * T + T - 2) * D + o_d
 | 
			
		||||
    p_dx = dx + (i_bh * T + T - 1) * D + o_d
 | 
			
		||||
    p_dg = dg + (i_bh * T + T - 1) * D + o_d
 | 
			
		||||
    p_do = do + (i_bh * T + T - 1) * D + o_d
 | 
			
		||||
 | 
			
		||||
    b_dh = tl.zeros([BD], dtype=tl.float32)
 | 
			
		||||
    for i in range(T - 1, -1, -1):
 | 
			
		||||
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        if i > 0:
 | 
			
		||||
            b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        elif USE_INITIAL_STATE:
 | 
			
		||||
            b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
 | 
			
		||||
        else:
 | 
			
		||||
            b_o = tl.zeros([BD], dtype=tl.float32)
 | 
			
		||||
 | 
			
		||||
        b_dh = b_dh + b_do
 | 
			
		||||
        b_dx = b_dh
 | 
			
		||||
        b_dh = b_dh * tl.exp(b_g)
 | 
			
		||||
        b_dg = b_dh * b_o
 | 
			
		||||
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
 | 
			
		||||
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)
 | 
			
		||||
 | 
			
		||||
        p_g -= D
 | 
			
		||||
        p_o -= D
 | 
			
		||||
        p_dx -= D
 | 
			
		||||
        p_dg -= D
 | 
			
		||||
        p_do -= D
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedRecurrentHGRNFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    def forward(ctx, x, g, initial_state=None, output_final_state=False):
        B, H, T, D = x.shape

        final_state = None
        if output_final_state:
            final_state = x.new_empty(B, H, D)

        o = torch.empty_like(x)
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
        fused_recurrent_hgrn_fwd_kernel[grid](
            x, g, o, initial_state, final_state,
            T, D,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=final_state is not None
        )
        ctx.save_for_backward(g, o, initial_state)
        return o, final_state

    @staticmethod
    @contiguous
    def backward(ctx, do, dht=None):
        g, o, initial_state = ctx.saved_tensors
        B, H, T, D = do.shape

        dx = torch.empty_like(o)
        dg = torch.empty_like(g)
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
        fused_recurrent_hgrn_bwd_kernel[grid](
            g, o, dx, dg, do, initial_state,
            T, D,
            USE_INITIAL_STATE=initial_state is not None,
        )

        return dx, dg, None, None


def fused_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    if initial_state is not None:
        initial_state = initial_state.detach()
    o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state)
    return o, final_state
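
# --- Illustrative sketch, not part of the vendored file ---
# A plain PyTorch reference for the recurrence that fused_recurrent_hgrn evaluates,
# as implied by the backward kernel above: h_t = exp(g_t) * h_{t-1} + x_t, o_t = h_t.
# Shapes are assumed to be [B, H, T, D], with the gate g given in log space.
import torch

def hgrn_recurrence_reference(x, g, initial_state=None):
    B, H, T, D = x.shape
    h = x.new_zeros(B, H, D) if initial_state is None else initial_state
    outputs = []
    for t in range(T):
        h = torch.exp(g[:, :, t]) * h + x[:, :, t]   # gated accumulation
        outputs.append(h)
    return torch.stack(outputs, dim=2), h            # matches (o, final_state)
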
12  finetune/lora/v6/fla/ops/linear_attn/__init__.py  vendored  Normal file
@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_linear_attn
from .chunk_fuse import fused_chunk_linear_attn
from .recurrent_fuse import fused_recurrent_linear_attn

__all__ = [
    'chunk_linear_attn',
    'fused_chunk_linear_attn',
    'fused_recurrent_linear_attn'
]
359  finetune/lora/v6/fla/ops/linear_attn/chunk.py  vendored  Normal file
@ -0,0 +1,359 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang

from typing import Tuple

import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd

from fla.utils import contiguous


@torch.jit.script
def normalize_output(q, k, o):
    k = k.transpose(-2, -1)
    k = k.cumsum(-1)
    k = k.transpose(-2, -1)
    z = (q * k).sum(-1, keepdim=True)
    return o / (z + 1e-5)

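# --- Illustrative check, not part of the vendored file ---
# normalize_output divides o by z_t = q_t . sum_{s<=t} k_s, which equals the row sum of the
# causal score matrix tril(q @ k^T); the small sketch below verifies that equivalence.
_q = torch.randn(1, 1, 4, 8)
_k = torch.randn(1, 1, 4, 8)
_z_cumsum = (_q * _k.cumsum(-2)).sum(-1, keepdim=True)
_z_matrix = torch.tril(_q @ _k.transpose(-2, -1)).sum(-1, keepdim=True)
assert torch.allclose(_z_cumsum, _z_matrix, atol=1e-5)
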
@triton.jit
def chunk_linear_attn_fwd_kernel_h(
    k,
    v,
    h,
    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr
):
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BK, BV]
        b_h += tl.dot(b_k, b_v, allow_tf32=False)

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))

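# --- Illustrative sketch, not part of the vendored file ---
# A plain PyTorch reference for the inter-chunk state pass computed by
# chunk_linear_attn_fwd_kernel_h above: the state stored for chunk i is the running sum
# of K_j^T V_j over all earlier chunks j < i (plus the optional initial state).
# Assumes k: [B, H, T, K], v: [B, H, T, V] and T divisible by the chunk size BT.
def chunk_states_reference(k, v, BT, initial_state=None):
    B, H, T, K = k.shape
    h = k.new_zeros(B, H, K, v.shape[-1]) if initial_state is None else initial_state.clone()
    hs = []
    for i in range(T // BT):
        hs.append(h.clone())                                    # state seen by chunk i
        b_k = k[:, :, i * BT:(i + 1) * BT]
        b_v = v[:, :, i * BT:(i + 1) * BT]
        h = h + b_k.transpose(-2, -1) @ b_v                     # h += K_i^T V_i
    return torch.stack(hs, dim=2), h                            # per-chunk states, final state
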
@triton.jit
def chunk_linear_attn_fwd_kernel_o(
    q,
    k,
    v,
    h,
    o,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    o_i = tl.arange(0, BT)
    m_s = o_i[:, None] >= o_i[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_s = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_q, b_h, allow_tf32=False)
        b_s += tl.dot(b_q, b_k, allow_tf32=False)

    b_s = tl.where(m_s, b_s, 0)
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
    p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

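# --- Illustrative sketch, not part of the vendored file ---
# A PyTorch reference for the output pass in chunk_linear_attn_fwd_kernel_o above:
# within chunk i, o_i = (Q_i @ H_i + tril(Q_i @ K_i^T) @ V_i) * scale, where H_i is the
# per-chunk state produced by the kernel_h pass (see chunk_states_reference).
def chunk_output_reference(q, k, v, hs, BT, scale):
    o = torch.empty_like(v)
    for i in range(q.shape[2] // BT):
        sl = slice(i * BT, (i + 1) * BT)
        b_q, b_k, b_v = q[:, :, sl], k[:, :, sl], v[:, :, sl]
        b_s = torch.tril(b_q @ b_k.transpose(-2, -1))           # intra-chunk causal scores
        o[:, :, sl] = (b_q @ hs[:, :, i] + b_s @ b_v) * scale   # inter- + intra-chunk terms
    return o
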
@triton.jit
def chunk_linear_attn_bwd_kernel_dh(
    q,
    do,
    dh,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr
):
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    for i_t in range(NT - 1, -1, -1):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, V]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BK, BV]
        b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)

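# --- Illustrative sketch, not part of the vendored file ---
# A PyTorch reference for the reverse state pass in chunk_linear_attn_bwd_kernel_dh above:
# the gradient state stored for chunk i accumulates (scale * Q_j)^T dO_j over later chunks j > i.
def chunk_dh_reference(q, do, BT, scale):
    B, H, T, K = q.shape
    b_dh = q.new_zeros(B, H, K, do.shape[-1])
    dhs = [None] * (T // BT)
    for i in range(T // BT - 1, -1, -1):
        dhs[i] = b_dh.clone()                                   # dh seen by chunk i
        sl = slice(i * BT, (i + 1) * BT)
        b_dh = b_dh + (q[:, :, sl] * scale).transpose(-2, -1) @ do[:, :, sl]
    return torch.stack(dhs, dim=2)                              # [B, H, NT, K, V]
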
@triton.jit
def chunk_linear_attn_bwd_kernel_dqkv(
    q,
    k,
    v,
    h,
    do,
    dh,
    dq,
    dk,
    dv,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    n_bh = tl.num_programs(2)
    o_i = tl.arange(0, BT)

    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
    b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BT, BT]
        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
        # [BT, BV]
        b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    # [BT, BT]
    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
    # [BT, BK]
    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))

    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

class ChunkLinearAttentionFunction(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    @contiguous
    def forward(ctx, q, k, v, scale, initial_state, output_final_state):
        B, H, T, K, V = *q.shape, v.shape[-1]
        BT = 64
        BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
        NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 4 if BK == 64 else 2
        ctx.scale = scale

        final_state = None
        if output_final_state:
            final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)

        h = q.new_empty(B, H, NT * K, V)
        grid = (NK, NV, B * H)
        chunk_linear_attn_fwd_kernel_h[grid](
            k, v, h, initial_state, final_state,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            h.stride(1), h.stride(2),
            H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=output_final_state,
            num_warps=num_warps,
            num_stages=num_stages
        )
        grid = (NV, NT, B * H)
        o = torch.empty_like(v)
        chunk_linear_attn_fwd_kernel_o[grid](
            q, k, v, h, o,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            h.stride(1), h.stride(2),
            scale,
            H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )

        ctx.save_for_backward(q, k, v, h)
        return o.to(q.dtype), final_state

    @staticmethod
    @custom_bwd
    @contiguous
    def backward(ctx, do, d_ht=None):
        q, k, v, h = ctx.saved_tensors

        B, H, T, K, V = *q.shape, v.shape[-1]
        BT = 64
        BK, BV = min(64, triton.next_power_of_2(K)), min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
        NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
        num_stages = 1
        num_warps = 4 if BK == 64 else 2
        scale = ctx.scale

        dh = q.new_empty(B, H, NT * K, V)
        grid = (NK, NV, B * H)
        chunk_linear_attn_bwd_kernel_dh[grid](
            q, do, dh,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            dh.stride(1), dh.stride(2),
            scale,
            H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            num_warps=num_warps,
            num_stages=num_stages
        )

        grid = (NK, NT, B * H)
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = v.new_empty(NK, *v.shape)
        num_stages = 1
        num_warps = 4 if BK == 64 else 2
        chunk_linear_attn_bwd_kernel_dqkv[grid](
            q, k, v, h, do, dh, dq, dk, dv,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            dh.stride(1), dh.stride(2),
            scale,
            H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            num_warps=num_warps,
            num_stages=num_stages
        )
        dv = dv.sum(0)
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None

def chunk_linear_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float = -1,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    normalize: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    if scale == -1:
        scale = q.shape[-1] ** -0.5
    if initial_state is not None:
        initial_state = initial_state.detach()
    o, final_state = ChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)

    if normalize:
        o = normalize_output(q * scale, k, o)

    return o, final_state
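
# --- Usage sketch, not part of the vendored file ---
# Hypothetical shapes; the Triton kernels above require a CUDA device.
if __name__ == '__main__':
    q = torch.randn(2, 4, 1024, 64, device='cuda', dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    o, final_state = chunk_linear_attn(q, k, v, output_final_state=True)
    o.sum().backward()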
Some files were not shown because too many files have changed in this diff.