RWKV-Runner/backend-python/rwkv_pip/cuda/util.h
2023-10-03 13:33:55 +08:00

8 lines
240 B
C++
Vendored

#include "ATen/ATen.h"
#include <cuda_fp16.h>
// Generic typed accessor: forwards to the tensor's own data_ptr<T>() member
// and returns the resulting T* unchanged. The tensor is received by value.
template <typename T>
T *data_ptr(torch::Tensor x) {
  T *typed = x.data_ptr<T>();
  return typed;
}
// Specialization for CUDA's `half`: asks the tensor for an at::Half pointer
// and reinterprets that pointer as `half *`.
// NOTE(review): this presumes at::Half and `half` share the same object
// representation — not shown in this header; confirm against the ATen docs.
template <>
inline half *data_ptr<half>(torch::Tensor x) {
  at::Half *as_at_half = x.data_ptr<at::Half>();
  return reinterpret_cast<half *>(as_at_half);
}