#include "ATen/ATen.h" #include #include #include #include "util.h" #include "element_wise.h" using torch::Tensor; void gemm_fp16_cublas(const void *a, const void *b, void *c, int m, int n, int k, bool output_fp32); // based on `kernel_wkv_forward`, fusing more operations __global__ void kernel_wkv_forward_new( const int B, const int T, const int C, const float *__restrict__ const _w, const float *__restrict__ const _u, const float *__restrict__ const _k, const float *__restrict__ const _v, const half *__restrict__ const r, half *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; const int _b = idx / C; const int _c = idx % C; const int _offset = _b * T * C + _c; const int _state_offset = _b * C + _c; float u = _u[_c]; float w = _w[_c]; const float *__restrict__ const k = _k + _offset; const float *__restrict__ const v = _v + _offset; half *__restrict__ const y = _y + _offset; float aa = _aa[_state_offset]; float bb = _bb[_state_offset]; float pp = _pp[_state_offset]; for (int i = 0; i < T; i++) { const int ii = i * C; const float kk = k[ii]; const float vv = v[ii]; float ww = u + kk; float p = max(pp, ww); float e1 = exp(pp - p); float e2 = exp(ww - p); y[ii] = __float2half((e1 * aa + e2 * vv) / (e1 * bb + e2)); ww = w + pp; p = max(ww, kk); e1 = exp(ww - p); e2 = exp(kk - p); aa = e1 * aa + e2 * vv; bb = e1 * bb + e2; pp = p; } _aa[_state_offset] = aa; _bb[_state_offset] = bb; _pp[_state_offset] = pp; } void cuda_wkv_forward_new(int B, int T, int C, float *w, float *u, float *k, float *v, half *r, half *y, float *aa, float *bb, float *pp) { dim3 threadsPerBlock(min(C, 32)); assert(B * C % threadsPerBlock.x == 0); dim3 numBlocks(B * C / threadsPerBlock.x); kernel_wkv_forward_new<<>>(B, T, C, w, u, k, v, r, y, aa, bb, pp); } __global__ void _att_mix(const half *xx, const half *sx, const half *k_mix, const half *v_mix, const half *r_mix, const int outer_size, const int inner_size, half *kx, half *vx, half *rx) { for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size; idx2 += blockDim.x * gridDim.x) { half k_mix_ = k_mix[idx2]; half v_mix_ = v_mix[idx2]; half r_mix_ = r_mix[idx2]; for (int row = 0; row < outer_size; ++row) { int idx1 = row * inner_size + idx2; half xx_ = xx[idx1]; half sx_ = sx[idx1]; kx[idx1] = __hadd(__hmul(xx_, k_mix_), __hmul(sx_, __hsub(__float2half(1), k_mix_))); vx[idx1] = __hadd(__hmul(xx_, v_mix_), __hmul(sx_, __hsub(__float2half(1), v_mix_))); rx[idx1] = __hadd(__hmul(xx_, r_mix_), __hmul(sx_, __hsub(__float2half(1), r_mix_))); } } } void att_mix(const half *xx, const half *sx, const half *k_mix, const half *v_mix, const half *r_mix, const int outer_size, const int inner_size, half *kx, half *vx, half *rx) { // 256 is good enough on most GPUs const int32_t BLOCK_SIZE = 256; assert(inner_size % BLOCK_SIZE == 0); _att_mix<<>>( xx, sx, k_mix, v_mix, r_mix, outer_size, inner_size, kx, vx, rx); } struct InplaceSigmoid { __device__ __forceinline__ half operator()(int i) const { ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i])))); } half *ptr; }; struct InplaceMul { __device__ __forceinline__ half operator()(int i) const { y[i] = __hmul(x[i], y[i]); } half *y; half *x; }; /* Equivalent Python code: xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) kx = xx * k_mix + sx * (1 - k_mix) vx = xx * v_mix + sx * (1 - v_mix) rx = xx * r_mix + sx * (1 - r_mix) r = torch.sigmoid(gemm(rx, rw)) k = gemm(kx, kw, output_dtype=torch.float32) v = gemm(vx, vw, output_dtype=torch.float32) T = x.shape[0] for t in range(T): kk = k[t] vv = v[t] ww = t_first + kk p = torch.maximum(pp, ww) e1 = torch.exp(pp - p) e2 = torch.exp(ww - p) sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) ww = t_decay + pp p = torch.maximum(ww, kk) e1 = torch.exp(ww - p) e2 = torch.exp(kk - p) aa = e1 * aa + e2 * vv bb = e1 * bb + e2 pp = p out = gemm(r * sx, ow) return x + out, xx[-1,:], aa, bb, pp */ Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix, Tensor v_mix, Tensor r_mix, Tensor kw, Tensor vw, Tensor rw, Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb, Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out) { Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b); sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0); char* buf_ptr = (char*)buf.data_ptr(); half* kx = (half*)buf_ptr; half* vx = kx + x.numel(); half* rx = vx + x.numel(); half* wkv_y = rx + x.numel(); att_mix(data_ptr(xx), data_ptr(sx), data_ptr(k_mix), data_ptr(v_mix), data_ptr(r_mix), xx.size(0), xx.size(1), kx, vx, rx); float* k = reinterpret_cast(wkv_y + x.numel()); float* v = k + x.size(0) * kw.size(1); half* r = reinterpret_cast(v + x.size(0) * vw.size(1)); gemm_fp16_cublas(kx, kw.data_ptr(), k, x.size(0), kw.size(1), kw.size(0), true); gemm_fp16_cublas(vx, vw.data_ptr(), v, x.size(0), vw.size(1), vw.size(0), true); gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), rw.size(0), false); element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1)); cuda_wkv_forward_new(1, x.size(0), x.size(1), data_ptr(t_decay), data_ptr(t_first), k, v, r, wkv_y, data_ptr(aa), data_ptr(bb), data_ptr(pp)); element_wise(InplaceMul{wkv_y, r}, x.numel()); gemm_fp16_cublas(wkv_y, ow.data_ptr(), x_plus_out.data_ptr(), x.size(0), ow.size(1), ow.size(0), false); x_plus_out += x; return xx; }