add rwkv-cuda-beta support (faster)
This commit is contained in:
7
backend-python/rwkv_pip/beta/cuda/util.h
vendored
Normal file
7
backend-python/rwkv_pip/beta/cuda/util.h
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
#include "ATen/ATen.h"
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
template <typename T> T *data_ptr(torch::Tensor x) { return x.data_ptr<T>(); }
|
||||
template <> inline half *data_ptr(torch::Tensor x) {
|
||||
return reinterpret_cast<half *>(x.data_ptr<at::Half>());
|
||||
}
|
||||
Reference in New Issue
Block a user