upgrade to rwkv 0.8.20
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#define CUBLAS_CHECK(condition) \
|
||||
for (cublasStatus_t _cublas_check_status = (condition); \
|
||||
@@ -18,26 +20,13 @@
|
||||
"CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
|
||||
" at " + std::to_string(__LINE__));
|
||||
|
||||
cublasHandle_t get_cublas_handle() {
|
||||
static cublasHandle_t cublas_handle = []() {
|
||||
cublasHandle_t handle = nullptr;
|
||||
CUBLAS_CHECK(cublasCreate(&handle));
|
||||
#if CUDA_VERSION < 11000
|
||||
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
|
||||
#else
|
||||
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
|
||||
#endif // CUDA_VERSION < 11000
|
||||
return handle;
|
||||
}();
|
||||
return cublas_handle;
|
||||
}
|
||||
|
||||
/*
|
||||
NOTE: blas gemm is column-major by default, but we need row-major output.
|
||||
The data of row-major, transposed matrix is exactly the same as the
|
||||
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
|
||||
*/
|
||||
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
const auto cuda_data_type = CUDA_R_16F;
|
||||
const auto cuda_c_data_type =
|
||||
c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
|
||||
@@ -55,7 +44,7 @@ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
|
||||
const int cublas_lda = m;
|
||||
const int cublas_ldb = k;
|
||||
const int cublas_ldc = m;
|
||||
cublasHandle_t cublas_handle = get_cublas_handle();
|
||||
cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
|
||||
4
backend-python/rwkv_pip/cuda/rwkv5_op.cpp
vendored
4
backend-python/rwkv_pip/cuda/rwkv5_op.cpp
vendored
@@ -1,5 +1,6 @@
|
||||
#include <torch/extension.h>
|
||||
#include "ATen/ATen.h"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
typedef at::BFloat16 bf16;
|
||||
typedef at::Half fp16;
|
||||
typedef float fp32;
|
||||
@@ -9,12 +10,15 @@ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *
|
||||
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
|
||||
|
||||
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
||||
cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
|
||||
}
|
||||
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
||||
cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
|
||||
}
|
||||
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
||||
cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user