#include "ATen/ATen.h"
// NOTE(review): the three angle-bracket include targets below were lost in the
// copy of the file under review; they are reconstructed from the names this
// file uses (half intrinsics, torch::Tensor) -- confirm against the build.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "element_wise.h"
#include "util.h"

// Element-wise functor producing both state tensors from one pass over `a`.
//
// Equivalent Python code:
//   s1 = t_first * a + s
//   s2 = a + t_decay * s
//
// `a`, `s`, `s1` and `s2` are indexed with the flat index `i`; `t_first` and
// `t_decay` are indexed with `i / inner_size`, i.e. one coefficient is shared
// by every run of `inner_size` consecutive elements.
struct Fused1 {
  const float *t_first;
  const float *t_decay;
  const float *a;
  const float *s;
  // Number of consecutive elements that share one t_first/t_decay entry.
  const int32_t inner_size;
  /* out */ float *s1;
  /* out */ float *s2;

  // Computes s1[i] and s2[i] for a single flat index `i`.
  __device__ void operator()(int i) const {
    const int j = i / inner_size;
    const float a_i = a[i];
    const float s_i = s[i];
    s1[i] = t_first[j] * a_i + s_i;
    s2[i] = a_i + t_decay[j] * s_i;
  }
};

/*
   Element-wise functor blending the current input `xx` with `sx` using three
   separate mixing weights, all in half precision.

   Equivalent Python code:
   kx = xx * k_mix + sx * (1 - k_mix)
   vx = xx * v_mix + sx * (1 - v_mix)
   rx = xx * r_mix + sx * (1 - r_mix)

   Every pointer is indexed with the same flat index `i`, so all eight buffers
   must hold at least as many elements as the launch covers.
*/
struct Mix {
  const half *xx;
  const half *sx;
  const half *k_mix;
  const half *v_mix;
  const half *r_mix;
  /* out */ half *kx;
  /* out */ half *vx;
  /* out */ half *rx;

  // Computes kx[i], vx[i] and rx[i] for a single flat index `i`.
  __device__ void operator()(int i) const {
    const half one = __float2half(1.0f);
    const half xx_ = xx[i];
    const half sx_ = sx[i];
    const half k_mix_ = k_mix[i];
    const half v_mix_ = v_mix[i];
    const half r_mix_ = r_mix[i];
    kx[i] = __hadd(__hmul(xx_, k_mix_), __hmul(sx_, __hsub(one, k_mix_)));
    vx[i] = __hadd(__hmul(xx_, v_mix_), __hmul(sx_, __hsub(one, v_mix_)));
    rx[i] = __hadd(__hmul(xx_, r_mix_), __hmul(sx_, __hsub(one, r_mix_)));
  }
};

using torch::Tensor;

// Defined elsewhere. Used below as (input, weight, destination); presumably
// writes the product of `a` and `b` into `c` -- confirm against its definition.
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);

// Runs one attention step and returns the layer-normalised input `xx`.
//
// Steps, in order:
//   1. xx = layer_norm(x) over the last dimension with ln_w / ln_b.
//   2. Mix blends xx with sx into kx, vx, rx (half precision).
//   3. r, k, v are filled by gemm_fp16_cublas_tensor from rx/kx/vx and the
//      rw/kw/vw weights, then reshaped to {H,1,S}, {H,S,1}, {H,1,S}.
//   4. a = matmul(k, v); Fused1 writes s1 and s2 from a, s, t_first, t_decay.
//   5. out = matmul(r, s1), flattened, group-normalised with H groups
//      (lx_w / lx_b), cast to half, then passed through gemm with `ow` into
//      x_plus_out, to which x is added in place.
//
// H is t_decay.size(0) and S is x.size(-1) / H.
//
// Tensors marked /* imm */ are caller-provided scratch buffers that this
// function writes; /* out */ tensors (x_plus_out, s2) carry results back to the
// caller. s1 is also written here and then consumed by step 5.
//
// Raises a c10::Error (via TORCH_CHECK) when t_decay has no entries along
// dimension 0 or when x.size(-1) is not a multiple of H.
//
// NOTE(review): element_wise and data_ptr come from the project headers and
// are not visible here; this code assumes element_wise(f, n) invokes f(i) for
// each i in [0, n) on the device and data_ptr<T>(t) yields t's raw T* storage.
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
                  Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix,
                  Tensor r_mix, Tensor kw,
                  /* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx,
                  Tensor rw,
                  /* imm */ Tensor rx, Tensor ow, Tensor t_first,
                  /* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v,
                  /* imm */ Tensor r, /* imm */ Tensor s1,
                  /* out */ Tensor x_plus_out, /* out */ Tensor s2) {
  // Validate the head split before launching anything: H == 0 would be an
  // integer division by zero on the host, and a non-divisible width would
  // silently truncate S.
  const int64_t width = x.size(-1);
  const int64_t H = t_decay.size(0);
  TORCH_CHECK(H > 0, "att_one_v5: t_decay.size(0) must be positive, got ", H);
  TORCH_CHECK(width % H == 0, "att_one_v5: x.size(-1) (", width,
              ") must be divisible by t_decay.size(0) (", H, ")");
  const int64_t S = width / H;

  Tensor xx = at::layer_norm(x, {width}, ln_w, ln_b);

  element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
                   data_ptr<half>(k_mix), data_ptr<half>(v_mix),
                   data_ptr<half>(r_mix), data_ptr<half>(kx),
                   data_ptr<half>(vx), data_ptr<half>(rx)},
               x.numel());

  gemm_fp16_cublas_tensor(rx, rw, r);
  r = at::reshape(r, {H, 1, S});
  gemm_fp16_cublas_tensor(kx, kw, k);
  k = at::reshape(k, {H, S, 1});
  gemm_fp16_cublas_tensor(vx, vw, v);
  v = at::reshape(v, {H, 1, S});

  {
    // {H,S,1} x {H,1,S}: one outer product per leading index.
    Tensor a = at::matmul(k, v);

    // s1 = t_first * a + s
    // s2 = a + t_decay * s
    // inner_size = a.size(1) * a.size(2), so t_first/t_decay are indexed by
    // a's leading dimension.
    element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay),
                        data_ptr<float>(a), data_ptr<float>(s),
                        static_cast<int32_t>(a.size(1) * a.size(2)),
                        data_ptr<float>(s1), data_ptr<float>(s2)},
                 a.numel());
  }

  Tensor out = at::matmul(r, s1);
  out = at::flatten(out);
  // group_norm is applied to a tensor with an extra leading dimension, so one
  // is added before the call and removed again afterwards.
  out = at::squeeze(at::group_norm(at::unsqueeze(out, 0), H, lx_w, lx_b), 0);
  out = at::_cast_Half(out);

  gemm_fp16_cublas_tensor(out, ow, x_plus_out);
  x_plus_out += x;
  return xx;
}