upgrade to rwkv 0.8.20

This commit is contained in:
josc146
2023-11-03 23:27:14 +08:00
parent 35e92d2aef
commit 1f81a1e5a8
6 changed files with 188 additions and 359 deletions

View File

@@ -3,6 +3,8 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#define CUBLAS_CHECK(condition) \
for (cublasStatus_t _cublas_check_status = (condition); \
@@ -18,26 +20,13 @@
"CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
" at " + std::to_string(__LINE__));
cublasHandle_t get_cublas_handle() {
static cublasHandle_t cublas_handle = []() {
cublasHandle_t handle = nullptr;
CUBLAS_CHECK(cublasCreate(&handle));
#if CUDA_VERSION < 11000
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#else
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
return handle;
}();
return cublas_handle;
}
/*
NOTE: blas gemm is column-major by default, but we need row-major output.
The data of row-major, transposed matrix is exactly the same as the
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
*/
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
const auto cuda_data_type = CUDA_R_16F;
const auto cuda_c_data_type =
c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
@@ -55,7 +44,7 @@ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
const int cublas_lda = m;
const int cublas_ldb = k;
const int cublas_ldc = m;
cublasHandle_t cublas_handle = get_cublas_handle();
cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
#if CUDA_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;