upgrade cuda-beta
This commit is contained in:
10
backend-python/rwkv_pip/beta/cuda/att_one.cu
vendored
10
backend-python/rwkv_pip/beta/cuda/att_one.cu
vendored
@@ -88,7 +88,7 @@ struct Mix {
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
||||
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
|
||||
|
||||
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||
Tensor v_mix, Tensor r_mix, Tensor kw,
|
||||
@@ -105,9 +105,9 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||
data_ptr<half>(vx), data_ptr<half>(rx)},
|
||||
x.numel());
|
||||
|
||||
gemm_fp16_cublas(kx, kw, k);
|
||||
gemm_fp16_cublas(vx, vw, v);
|
||||
gemm_fp16_cublas(rx, rw, r);
|
||||
gemm_fp16_cublas_tensor(kx, kw, k);
|
||||
gemm_fp16_cublas_tensor(vx, vw, v);
|
||||
gemm_fp16_cublas_tensor(rx, rw, r);
|
||||
at::sigmoid_(r);
|
||||
|
||||
element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
|
||||
@@ -118,7 +118,7 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||
data_ptr<half>(r)},
|
||||
x.numel());
|
||||
|
||||
gemm_fp16_cublas(r, ow, x_plus_out);
|
||||
gemm_fp16_cublas_tensor(r, ow, x_plus_out);
|
||||
x_plus_out += x;
|
||||
return xx;
|
||||
}
|
||||
|
||||
109
backend-python/rwkv_pip/beta/cuda/att_one_v5.cu
vendored
Normal file
109
backend-python/rwkv_pip/beta/cuda/att_one_v5.cu
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
#include "ATen/ATen.h"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "element_wise.h"
|
||||
#include "util.h"
|
||||
|
||||
// Equivalent Python code:
|
||||
// s1 = t_first * a + s
|
||||
// s2 = a + t_decay * s
|
||||
struct Fused1 {
|
||||
const float *t_first;
|
||||
const float *t_decay;
|
||||
const float *a;
|
||||
const float *s;
|
||||
const int32_t inner_size;
|
||||
/* out */ float *s1;
|
||||
/* out */ float *s2;
|
||||
|
||||
__device__ void operator()(int i) const {
|
||||
const int j = i / inner_size;
|
||||
s1[i] = t_first[j] * a[i] + s[i];
|
||||
s2[i] = a[i] + t_decay[j] * s[i];
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
Equivalent Python code:
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
*/
|
||||
|
||||
struct Mix {
|
||||
const half *xx;
|
||||
const half *sx;
|
||||
const half *k_mix;
|
||||
const half *v_mix;
|
||||
const half *r_mix;
|
||||
/* out */ half *kx;
|
||||
/* out */ half *vx;
|
||||
/* out */ half *rx;
|
||||
|
||||
__device__ void operator()(int i) const {
|
||||
half xx_ = xx[i];
|
||||
half sx_ = sx[i];
|
||||
half k_mix_ = k_mix[i];
|
||||
half v_mix_ = v_mix[i];
|
||||
half r_mix_ = r_mix[i];
|
||||
kx[i] = __hadd(__hmul(xx_, k_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
|
||||
vx[i] = __hadd(__hmul(xx_, v_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), v_mix_)));
|
||||
rx[i] = __hadd(__hmul(xx_, r_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
|
||||
}
|
||||
};
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
|
||||
|
||||
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
|
||||
Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix,
|
||||
Tensor r_mix, Tensor kw,
|
||||
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx,
|
||||
Tensor rw,
|
||||
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
|
||||
/* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v,
|
||||
/* imm */ Tensor r, /* imm */ Tensor s1,
|
||||
/* out */ Tensor x_plus_out, /* out */ Tensor s2) {
|
||||
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
|
||||
element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
|
||||
data_ptr<half>(k_mix), data_ptr<half>(v_mix),
|
||||
data_ptr<half>(r_mix), data_ptr<half>(kx),
|
||||
data_ptr<half>(vx), data_ptr<half>(rx)},
|
||||
x.numel());
|
||||
|
||||
int H = t_decay.size(0);
|
||||
int S = x.size(-1) / H;
|
||||
gemm_fp16_cublas_tensor(rx, rw, r);
|
||||
r = at::reshape(r, {H, 1, S});
|
||||
gemm_fp16_cublas_tensor(kx, kw, k);
|
||||
k = at::reshape(k, {H, S, 1});
|
||||
gemm_fp16_cublas_tensor(vx, vw, v);
|
||||
v = at::reshape(v, {H, 1, S});
|
||||
|
||||
{
|
||||
Tensor a = at::matmul(k, v);
|
||||
|
||||
// s1 = t_first * a + s
|
||||
// s2 = a + t_decay * s
|
||||
element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay),
|
||||
data_ptr<float>(a), data_ptr<float>(s),
|
||||
static_cast<int32_t>(a.size(1) * a.size(2)),
|
||||
data_ptr<float>(s1), data_ptr<float>(s2)},
|
||||
a.numel());
|
||||
}
|
||||
|
||||
Tensor out = at::matmul(r, s1);
|
||||
out = at::flatten(out);
|
||||
out = at::squeeze(at::group_norm(at::unsqueeze(out, 0), H, lx_w, lx_b), 0);
|
||||
out = at::_cast_Half(out);
|
||||
|
||||
gemm_fp16_cublas_tensor(out, ow, x_plus_out);
|
||||
x_plus_out += x;
|
||||
return xx;
|
||||
}
|
||||
1
backend-python/rwkv_pip/beta/cuda/att_seq.cu
vendored
1
backend-python/rwkv_pip/beta/cuda/att_seq.cu
vendored
@@ -8,7 +8,6 @@
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
||||
void gemm_fp16_cublas(const void *a, const void *b, void *c, int m,
|
||||
int n, int k, bool output_fp32);
|
||||
|
||||
|
||||
@@ -70,11 +70,59 @@ void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m,
|
||||
cuda_c_data_type, cublas_ldc, compute_type, algo));
|
||||
}
|
||||
|
||||
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
|
||||
// comptiable with rwkv one mode, 1-D tensor * 2-D tensor
|
||||
const int m = a.dense_dim() == 1 ? 1 : a.size(0);
|
||||
const int n = b.size(1);
|
||||
const int k = b.size(0);
|
||||
gemm_fp16_cublas(a.data_ptr(), b.data_ptr(), c.data_ptr(), m, n, k,
|
||||
c.dtype() == torch::kFloat32);
|
||||
/*
|
||||
NOTE: blas gemm is column-major by default, but we need row-major output.
|
||||
The data of row-major, transposed matrix is exactly the same as the
|
||||
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
|
||||
*/
|
||||
void gemm_fp16_cublas_tensor(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
|
||||
if (a.sizes().size() == 1) {
|
||||
assert(b.sizes().size() == 2);
|
||||
a = at::unsqueeze(a, 0);
|
||||
}
|
||||
const auto cuda_data_type = CUDA_R_16F;
|
||||
const auto cuda_c_data_type =
|
||||
c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
|
||||
const auto compute_type = CUDA_R_32F;
|
||||
const float sp_alpha = 1.f;
|
||||
// swap a and b, and use CUBLAS_OP_N. see the notes above
|
||||
std::swap(a, b);
|
||||
const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
|
||||
const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
|
||||
// m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
|
||||
// negative axis is used because of the existence of batch matmul.
|
||||
const int m = a.size(-1);
|
||||
const int k = a.size(-2);
|
||||
const int n = b.size(-2);
|
||||
const int cublas_lda = m;
|
||||
const int cublas_ldb = k;
|
||||
const int cublas_ldc = m;
|
||||
cublasHandle_t cublas_handle = get_cublas_handle();
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
#else
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
|
||||
#endif
|
||||
const float sp_beta = 0.f;
|
||||
if (a.sizes().size() == 2 && b.sizes().size() == 2) {
|
||||
CUBLAS_CHECK(cublasGemmEx(
|
||||
cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
|
||||
a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
|
||||
cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
|
||||
compute_type, algo));
|
||||
} else {
|
||||
// batch matmul
|
||||
assert(a.sizes().size() == 3 && b.sizes().size() == 3);
|
||||
|
||||
const long long int cublas_stride_a = m * k;
|
||||
const long long int cublas_stride_b = k * n;
|
||||
const long long int cublas_stride_c = m * n;
|
||||
CUBLAS_CHECK(cublasGemmStridedBatchedEx(
|
||||
cublas_handle, cublas_trans_a, cublas_trans_b, m,
|
||||
n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
|
||||
cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
|
||||
&sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
|
||||
a.size(0), compute_type, algo));
|
||||
}
|
||||
}
|
||||
|
||||
20
backend-python/rwkv_pip/beta/cuda/wrapper.cpp
vendored
20
backend-python/rwkv_pip/beta/cuda/wrapper.cpp
vendored
@@ -118,7 +118,9 @@ void mm8_one(int64_t N, int64_t M,
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
||||
#ifndef DISABLE_CUBLAS_GEMM
|
||||
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
|
||||
#endif
|
||||
|
||||
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||
Tensor v_mix, Tensor r_mix, Tensor kw,
|
||||
@@ -134,6 +136,16 @@ Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
||||
Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
|
||||
Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out);
|
||||
|
||||
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
|
||||
Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix,
|
||||
Tensor r_mix, Tensor kw,
|
||||
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx,
|
||||
Tensor rw,
|
||||
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
|
||||
/* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v,
|
||||
/* imm */ Tensor r, /* imm */ Tensor s1,
|
||||
/* out */ Tensor x_plus_out, /* out */ Tensor s2);
|
||||
|
||||
Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
||||
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
|
||||
/* imm */ Tensor buf,
|
||||
@@ -148,8 +160,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("wkv_forward", &wkv_forward, "wkv forward");
|
||||
m.def("mm8_seq", &mm8_seq, "mm8 seq");
|
||||
m.def("mm8_one", &mm8_one, "mm8 one");
|
||||
m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
|
||||
m.def("gemm_fp16_cublas", &gemm_fp16_cublas_tensor, "gemv fp16 cublas");
|
||||
m.def("att_one", &att_one, "att one");
|
||||
m.def("att_one_v5", &att_one_v5, "att one v5");
|
||||
m.def("att_seq", &att_seq, "att seq");
|
||||
m.def("ffn_seq", &ffn_seq, "ffn seq");
|
||||
m.def("ffn_one", &ffn_one, "ffn one");
|
||||
@@ -159,8 +172,9 @@ TORCH_LIBRARY(rwkv, m) {
|
||||
m.def("wkv_forward", wkv_forward);
|
||||
m.def("mm8_seq", mm8_seq);
|
||||
m.def("mm8_one", mm8_one);
|
||||
m.def("gemm_fp16_cublas", gemm_fp16_cublas);
|
||||
m.def("gemm_fp16_cublas", gemm_fp16_cublas_tensor);
|
||||
m.def("att_one", att_one);
|
||||
m.def("att_one_v5", &att_one_v5);
|
||||
m.def("att_seq", att_seq);
|
||||
m.def("ffn_seq", ffn_seq);
|
||||
m.def("ffn_one", ffn_one);
|
||||
|
||||
Reference in New Issue
Block a user