#include "ATen/ATen.h"
// NOTE(review): the three system header names below were missing from the
// source as received (bare `#include` directives). They are restored from what
// this file visibly uses (`half` intrinsics, `torch::Tensor`) -- confirm them
// against the build before merging.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "element_wise.h"
#include "util.h"

// Equivalent Python code:
// ww = t_first + k
// p = torch.maximum(pp, ww)
// e1 = torch.exp(pp - p)
// e2 = torch.exp(ww - p)
// wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
// ww = t_decay + pp
// p = torch.maximum(ww, k)
// e1 = torch.exp(ww - p)
// e2 = torch.exp(k - p)
// t1 = e1 * aa + e2 * v
// t2 = e1 * bb + e2
// r = r * wkv
// return t1, t2, p, r
//
// Per-element functor: operator()(i) reads element i of every input array,
// scales r[i] in place by the wkv value, and writes t1[i], t2[i] and p[i].
// All loads for element i happen before the stores to t1/t2/p.
struct WkvForwardOne {
  const float *t_first;
  const float *k;
  const float *pp;
  const float *aa;
  const float *bb;
  const float *t_decay;
  const float *v;
  /* out */ float *t1;
  /* out */ float *t2;
  /* out */ float *p;
  /* in & out */ half *r;

  __device__ void operator()(int i) const {
    // Load every input for this element once; k is used by both halves below.
    const float k_ = k[i];
    const float pp_ = pp[i];
    const float aa_ = aa[i];
    const float bb_ = bb[i];
    const float v_ = v[i];

    // First half: wkv, computed relative to max(pp, t_first + k) so that the
    // exponent arguments are never positive.
    float ww = t_first[i] + k_;
    float p_ = (pp_ > ww) ? pp_ : ww;
    float e1 = expf(pp_ - p_);
    float e2 = expf(ww - p_);
    // The quotient is computed in float and rounded to half before the
    // multiply, matching `.to(dtype=x.dtype)` in the Python reference.
    r[i] = __hmul(r[i],
                  __float2half(((e1 * aa_ + e2 * v_) / (e1 * bb_ + e2))));

    // Second half: the updated t1 / t2 / p values, relative to
    // max(t_decay + pp, k).
    ww = t_decay[i] + pp_;
    p_ = (ww > k_) ? ww : k_;
    e1 = expf(ww - p_);
    e2 = expf(k_ - p_);
    t1[i] = e1 * aa_ + e2 * v_;
    t2[i] = e1 * bb_ + e2;
    p[i] = p_;
  }
};

// Returns a * w + b * (1 - w), evaluated entirely in half precision with the
// same operation order for every caller.
static __device__ __forceinline__ half mix_half(half a, half b, half w) {
  return __hadd(__hmul(a, w), __hmul(b, __hsub(__float2half(1), w)));
}

/*
   Equivalent Python code:
   kx = xx * k_mix + sx * (1 - k_mix)
   vx = xx * v_mix + sx * (1 - v_mix)
   rx = xx * r_mix + sx * (1 - r_mix)

   Per-element functor: operator()(i) blends xx[i] and sx[i] with each of the
   three mixing weights and stores the results in kx[i], vx[i] and rx[i].
*/
struct Mix {
  const half *xx;
  const half *sx;
  const half *k_mix;
  const half *v_mix;
  const half *r_mix;
  /* out */ half *kx;
  /* out */ half *vx;
  /* out */ half *rx;

  __device__ void operator()(int i) const {
    // Finish all loads before the first store.
    const half xx_ = xx[i];
    const half sx_ = sx[i];
    const half k_mix_ = k_mix[i];
    const half v_mix_ = v_mix[i];
    const half r_mix_ = r_mix[i];

    kx[i] = mix_half(xx_, sx_, k_mix_);
    vx[i] = mix_half(xx_, sx_, v_mix_);
    rx[i] = mix_half(xx_, sx_, r_mix_);
  }
};

using torch::Tensor;

// Defined elsewhere. Presumably writes the fp16 product of a and b into c --
// confirm against its definition.
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);

// Runs one attention step on x:
//   1. xx = layer_norm(x) over the last dimension, using ln_w / ln_b.
//   2. Mix xx with sx into kx / vx / rx (see Mix).
//   3. k, v, r = gemm_fp16_cublas of kx / vx / rx with kw / vw / rw, then
//      sigmoid applied to r in place.
//   4. WkvForwardOne updates r in place and fills t1, t2 and p.
//   5. x_plus_out = gemm_fp16_cublas(r, ow) with x added in place.
//
// Parameters marked `imm` are caller-provided buffers that this function
// overwrites (presumably intermediates -- confirm with the caller). `ww` is
// accepted but not used by this function.
//
// Both element-wise kernels index raw pointers for i in [0, x.numel()), so
// every tensor they touch must be contiguous and hold at least x.numel()
// elements; this is verified up front and reported through TORCH_CHECK.
//
// Returns xx, the layer-normalised input.
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
               Tensor v_mix, Tensor r_mix, Tensor kw,
               /* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, Tensor rw,
               /* imm */ Tensor rx, Tensor ow, Tensor t_first,
               /* imm */ Tensor k, Tensor pp, Tensor ww, Tensor aa, Tensor bb,
               Tensor t_decay, /* imm */ Tensor v, /* in & out */ Tensor r,
               /* out */ Tensor x_plus_out, /* out */ Tensor t1,
               /* out */ Tensor t2, /* out */ Tensor p) {
  const int64_t n = x.numel();

  // Validate everything that is handed to a kernel as a raw pointer, before
  // any work is launched.
  struct Operand {
    const char *name;
    const Tensor *tensor;
  };
  const Operand operands[] = {
      {"sx", &sx},           {"k_mix", &k_mix}, {"v_mix", &v_mix},
      {"r_mix", &r_mix},     {"kx", &kx},       {"vx", &vx},
      {"rx", &rx},           {"t_first", &t_first}, {"k", &k},
      {"pp", &pp},           {"aa", &aa},       {"bb", &bb},
      {"t_decay", &t_decay}, {"v", &v},         {"r", &r},
      {"t1", &t1},           {"t2", &t2},       {"p", &p},
  };
  for (const Operand &op : operands) {
    TORCH_CHECK(op.tensor->is_contiguous(), "att_one: ", op.name,
                " must be contiguous");
    TORCH_CHECK(op.tensor->numel() >= n, "att_one: ", op.name, " has ",
                op.tensor->numel(), " elements, but at least ", n,
                " are required");
  }

  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);

  element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
                   data_ptr<half>(k_mix), data_ptr<half>(v_mix),
                   data_ptr<half>(r_mix), data_ptr<half>(kx),
                   data_ptr<half>(vx), data_ptr<half>(rx)},
               x.numel());

  gemm_fp16_cublas(kx, kw, k);
  gemm_fp16_cublas(vx, vw, v);
  gemm_fp16_cublas(rx, rw, r);
  at::sigmoid_(r);

  element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
                             data_ptr<float>(pp), data_ptr<float>(aa),
                             data_ptr<float>(bb), data_ptr<float>(t_decay),
                             data_ptr<float>(v), data_ptr<float>(t1),
                             data_ptr<float>(t2), data_ptr<float>(p),
                             data_ptr<half>(r)},
               x.numel());

  gemm_fp16_cublas(r, ow, x_plus_out);
  x_plus_out += x;
  return xx;
}