110 lines
3.2 KiB
Plaintext
110 lines
3.2 KiB
Plaintext
|
#include "ATen/ATen.h"
|
||
|
#include <cuda_fp16.h>
|
||
|
#include <cuda_runtime.h>
|
||
|
#include <torch/extension.h>
|
||
|
|
||
|
#include "element_wise.h"
|
||
|
#include "util.h"
|
||
|
|
||
|
// Equivalent Python code:
|
||
|
// s1 = t_first * a + s
|
||
|
// s2 = a + t_decay * s
|
||
|
// Element-wise functor. For a flat index i, with g = i / inner_size:
//   s1[i] = t_first[g] * a[i] + s[i]
//   s2[i] = a[i] + t_decay[g] * s[i]
// Every run of inner_size consecutive indices shares one t_first / t_decay
// entry. The field order matters: callers build this struct with brace
// initialisation.
struct Fused1 {
  const float *t_first;      // per-group factor multiplied with a[i] for s1
  const float *t_decay;      // per-group factor multiplied with s[i] for s2
  const float *a;            // input, read at index i
  const float *s;            // input, read at index i
  const int32_t inner_size;  // number of consecutive indices per group
  /* out */ float *s1;
  /* out */ float *s2;

  // Processes the single element at flat index i.
  __device__ void operator()(int i) const {
    const int group = i / inner_size;

    const float first = t_first[group];
    s1[i] = first * a[i] + s[i];

    // a[i] and s[i] are re-read after the s1 store rather than cached in
    // locals, so the statement order of loads and stores stays as written.
    const float decay = t_decay[group];
    s2[i] = a[i] + decay * s[i];
  }
};
|
||
|
|
||
|
/*
|
||
|
Equivalent Python code:
|
||
|
kx = xx * k_mix + sx * (1 - k_mix)
|
||
|
vx = xx * v_mix + sx * (1 - v_mix)
|
||
|
rx = xx * r_mix + sx * (1 - r_mix)
|
||
|
*/
|
||
|
|
||
|
// Element-wise functor over half-precision buffers. For a flat index i:
//   kx[i] = xx[i] * k_mix[i] + sx[i] * (1 - k_mix[i])
//   vx[i] = xx[i] * v_mix[i] + sx[i] * (1 - v_mix[i])
//   rx[i] = xx[i] * r_mix[i] + sx[i] * (1 - r_mix[i])
// All five inputs are indexed with the same i, and the arithmetic is done in
// half via __hadd / __hsub / __hmul. The field order matters: callers build
// this struct with brace initialisation.
struct Mix {
  const half *xx;
  const half *sx;
  const half *k_mix;
  const half *v_mix;
  const half *r_mix;
  /* out */ half *kx;
  /* out */ half *vx;
  /* out */ half *rx;

  // Returns x * m + s * (1 - m), evaluated with half intrinsics.
  static __device__ half blend(half x, half s, half m) {
    return __hadd(__hmul(x, m), __hmul(s, __hsub(__float2half(1), m)));
  }

  // Processes the single element at flat index i.
  __device__ void operator()(int i) const {
    // Every input is loaded before the first store.
    const half x_i = xx[i];
    const half s_i = sx[i];
    const half wk = k_mix[i];
    const half wv = v_mix[i];
    const half wr = r_mix[i];

    kx[i] = blend(x_i, s_i, wk);
    vx[i] = blend(x_i, s_i, wv);
    rx[i] = blend(x_i, s_i, wr);
  }
};
|
||
|
|
||
|
using torch::Tensor;
|
||
|
|
||
|
// Forward declaration only; no definition is visible in this file.
// NOTE(review): every call in this file passes the tensor it reads
// afterwards as `c`, so `c` looks like the destination operand -- confirm
// against the definition, along with the dtypes and shapes it accepts.
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
|
||
|
|
||
|
// Host-side driver for one att_one_v5 step. It issues, in order:
//   1. xx = layer_norm(x) over the last dimension, with ln_w / ln_b.
//   2. The Mix functor over x.numel() elements, writing kx, vx and rx from
//      xx, sx and the three *_mix tensors.
//   3. gemm_fp16_cublas_tensor(rx, rw, r), (kx, kw, k) and (vx, vw, v); the
//      results are then reshaped to {H, 1, S}, {H, S, 1} and {H, 1, S}, where
//      H = t_decay.size(0) and S = x.size(-1) / H.
//   4. a = matmul(k, v), followed by the Fused1 functor over a.numel()
//      elements, writing s1 and s2 from t_first, t_decay, a and s.
//   5. out = matmul(r, s1), flattened, group-normalised with H groups using
//      lx_w / lx_b, cast to half, and passed with ow to
//      gemm_fp16_cublas_tensor, which receives x_plus_out as its third
//      argument; x is then added to x_plus_out in place.
//
// Parameters tagged /* imm */ are caller-provided buffers filled here as
// intermediates; parameters tagged /* out */ receive the results.
// Returns xx, the layer-normalised input.
//
// NOTE(review): data_ptr<half> is requested for xx, sx, k_mix, v_mix, r_mix,
// kx, vx and rx, and data_ptr<float> for t_first, t_decay, a, s, s1 and s2.
// data_ptr<T> is declared outside this file; presumably it validates the
// dtype -- confirm.
//
// Raises a c10::Error (via TORCH_CHECK) when t_decay.size(0) is not
// positive, instead of performing an integer division by zero.
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
                  Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix,
                  Tensor r_mix, Tensor kw,
                  /* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx,
                  Tensor rw,
                  /* imm */ Tensor rx, Tensor ow, Tensor t_first,
                  /* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v,
                  /* imm */ Tensor r, /* imm */ Tensor s1,
                  /* out */ Tensor x_plus_out, /* out */ Tensor s2) {
  // Validate before any kernel is launched so a bad call has no side effects.
  // Sizes are kept as int64_t to avoid narrowing the values ATen returns.
  const int64_t H = t_decay.size(0);
  TORCH_CHECK(H > 0, "att_one_v5: t_decay.size(0) must be positive, got ", H);
  const int64_t S = x.size(-1) / H;

  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
  element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
                   data_ptr<half>(k_mix), data_ptr<half>(v_mix),
                   data_ptr<half>(r_mix), data_ptr<half>(kx),
                   data_ptr<half>(vx), data_ptr<half>(rx)},
               x.numel());

  // Each reshape rebinds the local handle only; the caller's tensor object is
  // not modified by the assignment.
  gemm_fp16_cublas_tensor(rx, rw, r);
  r = at::reshape(r, {H, 1, S});
  gemm_fp16_cublas_tensor(kx, kw, k);
  k = at::reshape(k, {H, S, 1});
  gemm_fp16_cublas_tensor(vx, vw, v);
  v = at::reshape(v, {H, 1, S});

  {
    // {H, S, 1} x {H, 1, S}: one S-by-S product per leading index.
    Tensor a = at::matmul(k, v);

    // s1 = t_first * a + s
    // s2 = a + t_decay * s
    // inner_size = a.size(1) * a.size(2), so each leading index of `a` uses
    // one entry of t_first / t_decay.
    element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay),
                        data_ptr<float>(a), data_ptr<float>(s),
                        static_cast<int32_t>(a.size(1) * a.size(2)),
                        data_ptr<float>(s1), data_ptr<float>(s2)},
                 a.numel());
  }

  Tensor out = at::matmul(r, s1);
  out = at::flatten(out);
  // group_norm is applied to a view with a leading dimension of size 1, which
  // is removed again afterwards.
  out = at::squeeze(at::group_norm(at::unsqueeze(out, 0), H, lx_w, lx_b), 0);
  out = at::_cast_Half(out);

  gemm_fp16_cublas_tensor(out, ow, x_plus_out);
  x_plus_out += x;
  return xx;
}
|