// C++ (original listing: 23 lines, 1.4 KiB)
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <limits>

#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
|
||
|
|
||
|
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
|
||
|
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
|
||
|
|
||
|
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||
|
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
|
||
|
}
|
||
|
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
|
||
|
cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
|
||
|
}
|
||
|
// Python-module bindings. TORCH_EXTENSION_NAME is the module name supplied by
// the torch.utils.cpp_extension build. The third argument of each def() is the
// docstring shown by help() in Python.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward",
          &forward,
          "wkv5 forward");
    m.def("backward",
          &backward,
          "wkv5 backward");
}

// Registers the same two entry points as dispatcher operators in the `wkv5`
// namespace (reachable as torch.ops.wkv5.forward / torch.ops.wkv5.backward).
// Each schema is inferred from the C++ signature, which is why the dimension
// parameters are int64_t.
TORCH_LIBRARY(wkv5, m) {
    m.def("forward", &forward);
    m.def("backward", &backward);
}