#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "element_wise.h"
#include "util.h"

using torch::Tensor;

void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m,
                      int ori_n, int ori_k, bool output_fp32);
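// gemm_fp16_cublas is implemented in a separate translation unit. Judging
// from the call sites in this file, it computes c = a @ b with a of shape
// (ori_m, ori_k) and b of shape (ori_k, ori_n), both fp16, writing c as fp16
// (or fp32 when output_fp32 is true). This is inferred from usage, not a
// documented contract.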
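
/*
   _ffn_seq_mix below is the token-shift mix over a whole sequence; in Python
   terms (the same formulas as in the ffn_seq comment further down):

     kx = xx * k_mix + sx * (1 - k_mix)
     rx = xx * r_mix + sx * (1 - r_mix)

   Each thread owns one channel (idx2) and walks every row, so k_mix_ and
   r_mix_ are loaded once per channel and reused across the sequence.
*/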
__global__ void _ffn_seq_mix(const half *xx, const half *sx, const half *k_mix,
                             const half *r_mix, const int outer_size,
                             const int inner_size, half *kx, half *rx) {
  for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size;
       idx2 += blockDim.x * gridDim.x) {
    half k_mix_ = k_mix[idx2];
    half r_mix_ = r_mix[idx2];
    for (int row = 0; row < outer_size; ++row) {
      int idx1 = row * inner_size + idx2;
      half xx_ = xx[idx1];
      half sx_ = sx[idx1];
      kx[idx1] = __hadd(__hmul(xx_, k_mix_),
                        __hmul(sx_, __hsub(__float2half(1), k_mix_)));
      rx[idx1] = __hadd(__hmul(xx_, r_mix_),
                        __hmul(sx_, __hsub(__float2half(1), r_mix_)));
    }
  }
}

void ffn_seq_mix(const half *xx, const half *sx, const half *k_mix,
                 const half *r_mix, const int outer_size, const int inner_size,
                 half *kx, half *rx) {
  // 256 is good enough on most GPUs
  const int32_t BLOCK_SIZE = 256;
  assert(inner_size % BLOCK_SIZE == 0);
  _ffn_seq_mix<<<inner_size / BLOCK_SIZE, BLOCK_SIZE>>>(
      xx, sx, k_mix, r_mix, outer_size, inner_size, kx, rx);
}
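
// Illustrative use (mirroring the call in ffn_seq below): for a sequence of
// T tokens with hidden size C,
//   ffn_seq_mix(xx, sx, k_mix, r_mix, /*outer_size=*/T, /*inner_size=*/C, kx, rx);
// where every pointer is a device pointer to fp16 data, presumably contiguous.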

struct InplaceSigmoid {
  __device__ __forceinline__ void operator()(int i) const {
    ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i]))));
  }
  half *ptr;
};

struct InplaceReLUAndSquare {
  __device__ __forceinline__ void operator()(int i) const {
    // __hmax is not defined in old cuda
    if (__hgt(ptr[i], __float2half(0))) {
      ptr[i] = __hmul(ptr[i], ptr[i]);
    } else {
      ptr[i] = __float2half(0);
    }
  }
  half *ptr;
};

struct InplaceFma {
  __device__ __forceinline__ void operator()(int i) const {
    a[i] = __hfma(a[i], b[i], c[i]);
  }
  half *a;
  const half *b;
  const half *c;
};
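
/*
   element_wise comes from element_wise.h; only its usage is visible in this
   file (element_wise(functor, n) applies functor(i) on the GPU for every
   i in [0, n)). A minimal sketch of such a launcher, purely illustrative and
   not the actual implementation, could look like:

     template <typename Func>
     __global__ void _element_wise(Func func, int n) {
       for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
            i += blockDim.x * gridDim.x) {
         func(i);
       }
     }

     template <typename Func>
     void element_wise(Func func, int n) {
       _element_wise<<<(n + 255) / 256, 256>>>(func, n);
     }
*/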

/*
   Equivalent Python code:

   xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
   sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
   kx = xx * k_mix + sx * (1 - k_mix)
   rx = xx * r_mix + sx * (1 - r_mix)

   r = torch.sigmoid(gemm(rx, rw))
   vx = torch.square(torch.relu(gemm(kx, kw)))
   out = r * gemm(vx, vw)
   return x + out, xx[-1,:]
*/
Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
               Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
               /* imm */ Tensor buf,
               /* out */ Tensor x_plus_out) {
  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
  sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0);
  char *buf_ptr = (char *)buf.data_ptr();
  half *kx = (half *)buf_ptr;
  half *rx = kx + x.numel();
  half *vx = rx + x.numel();
  half *r = vx + x.size(0) * kw.size(1);
  ffn_seq_mix(data_ptr<half>(xx), data_ptr<half>(sx), data_ptr<half>(k_mix),
              data_ptr<half>(r_mix), xx.size(0), xx.size(1), kx, rx);

  gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), x.size(1),
                   false);
  element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1));
  gemm_fp16_cublas(kx, kw.data_ptr(), vx, x.size(0), kw.size(1), x.size(1),
                   false);
  element_wise(InplaceReLUAndSquare{vx}, x.size(0) * kw.size(1));
  gemm_fp16_cublas(vx, vw.data_ptr(), x_plus_out.data_ptr(), x.size(0),
                   vw.size(1), vw.size(0), false);
  element_wise(InplaceFma{data_ptr<half>(x_plus_out), r, data_ptr<half>(x)},
               x_plus_out.numel());
  return xx;
}
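
// Scratch regions carved out of `buf` by ffn_seq above, in half elements:
//   kx: x.numel(), rx: x.numel(), vx: x.size(0) * kw.size(1),
//   r:  x.size(0) * rw.size(1)
// so buf presumably has to hold at least the sum of those four regions.
// ffn_one below slices its buf the same way.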

struct FfnOneMix {
  __device__ __forceinline__ void operator()(int idx) {
    half k_mix_ = k_mix[idx];
    half r_mix_ = r_mix[idx];
    half xx_ = xx[idx];
    half sx_ = sx[idx];
    kx[idx] = __hadd(__hmul(xx_, k_mix_),
                     __hmul(sx_, __hsub(__float2half(1), k_mix_)));
    rx[idx] = __hadd(__hmul(xx_, r_mix_),
                     __hmul(sx_, __hsub(__float2half(1), r_mix_)));
  }
  half *k_mix;
  half *r_mix;
  half *xx;
  half *sx;
  half *kx;
  half *rx;
};
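
// FfnOneMix is the single-token counterpart of _ffn_seq_mix: the same mix
// formulas, but applied through the generic element_wise launcher rather than
// a dedicated kernel, since there is no row loop to amortize.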

/*
   Equivalent Python code:

   xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
   kx = xx * k_mix + sx * (1 - k_mix)
   rx = xx * r_mix + sx * (1 - r_mix)

   r = torch.sigmoid(gemm(rx, rw))
   vx = torch.square(torch.relu(gemm(kx, kw)))
   out = r * gemm(vx, vw)
   return x + out, xx
*/
Tensor ffn_one(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
               Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
               /* imm */ Tensor buf,
               /* out */ Tensor x_plus_out) {
  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
  char *buf_ptr = (char *)buf.data_ptr();
  half *kx = (half *)buf_ptr;
  half *rx = kx + x.numel();
  half *vx = rx + x.numel();
  half *r = vx + x.size(0) * kw.size(1);
  element_wise(FfnOneMix{data_ptr<half>(k_mix), data_ptr<half>(r_mix),
                         data_ptr<half>(xx), data_ptr<half>(sx), kx, rx},
               x.numel());
  // vector * matrix, so m = 1
  gemm_fp16_cublas(rx, rw.data_ptr(), r, 1, rw.size(1), rw.size(0), false);
  element_wise(InplaceSigmoid{r}, rw.size(1));
  gemm_fp16_cublas(kx, kw.data_ptr(), vx, 1, kw.size(1), kw.size(0), false);
  element_wise(InplaceReLUAndSquare{vx}, kw.size(1));
  gemm_fp16_cublas(vx, vw.data_ptr(), x_plus_out.data_ptr(), 1, vw.size(1),
                   vw.size(0), false);
  element_wise(InplaceFma{data_ptr<half>(x_plus_out), r, data_ptr<half>(x)},
               x_plus_out.numel());
  return xx;
}
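
/*
   ffn_seq and ffn_one are presumably exposed to Python through a torch
   extension binding defined elsewhere in the repo; an illustrative (not
   authoritative) binding would look like:

     PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       m.def("ffn_seq", &ffn_seq, "RWKV FFN over a whole sequence");
       m.def("ffn_one", &ffn_one, "RWKV FFN for a single token");
     }
*/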