add rwkv-cuda-beta support (faster)

This commit is contained in:
josc146
2023-08-14 22:07:15 +08:00
parent da68926e9c
commit 8a13bd3c1e
20 changed files with 2550 additions and 20 deletions

View File

@@ -0,0 +1,7 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
template <typename T> T *data_ptr(torch::Tensor x) { return x.data_ptr<T>(); }
template <> inline half *data_ptr(torch::Tensor x) {
return reinterpret_cast<half *>(x.data_ptr<at::Half>());
}