#include #include "ATen/ATen.h" typedef at::BFloat16 bf16; void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y); void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs); void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) { cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), y.data_ptr()); } void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) { cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr(), gs.data_ptr()); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &forward, "wkv6state forward"); m.def("backward", &backward, "wkv6state backward"); } TORCH_LIBRARY(wkv6state, m) { m.def("forward", forward); m.def("backward", backward); }