upgrade to rwkv 0.8.22 (rwkv6 support)
This commit is contained in:
parent
bea3c29c1c
commit
0063c171f3
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@ -109,6 +109,7 @@ jobs:
|
||||
go install github.com/wailsapp/wails/v2/cmd/wails@latest
|
||||
rm ./backend-python/rwkv_pip/wkv_cuda.pyd
|
||||
rm ./backend-python/rwkv_pip/rwkv5.pyd
|
||||
rm ./backend-python/rwkv_pip/rwkv6.pyd
|
||||
rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd
|
||||
rm ./backend-python/get-pip.py
|
||||
make
|
||||
@ -141,6 +142,7 @@ jobs:
|
||||
go install github.com/wailsapp/wails/v2/cmd/wails@latest
|
||||
rm ./backend-python/rwkv_pip/wkv_cuda.pyd
|
||||
rm ./backend-python/rwkv_pip/rwkv5.pyd
|
||||
rm ./backend-python/rwkv_pip/rwkv6.pyd
|
||||
rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd
|
||||
rm ./backend-python/get-pip.py
|
||||
make
|
||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -18,7 +18,7 @@ __pycache__
|
||||
/cmd-helper.bat
|
||||
/install-py-dep.bat
|
||||
/backend-python/wkv_cuda
|
||||
/backend-python/rwkv5
|
||||
/backend-python/rwkv*
|
||||
*.exe
|
||||
*.old
|
||||
.DS_Store
|
||||
|
87
backend-python/rwkv_pip/cuda/rwkv6.cu
vendored
Normal file
87
backend-python/rwkv_pip/cuda/rwkv6.cu
vendored
Normal file
@ -0,0 +1,87 @@
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
#include "ATen/ATen.h"
|
||||
typedef at::BFloat16 bf16;
|
||||
typedef at::Half fp16;
|
||||
typedef float fp32;
|
||||
|
||||
template <typename F>
|
||||
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
|
||||
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
|
||||
F *__restrict__ const _y)
|
||||
{
|
||||
const int b = blockIdx.x / H;
|
||||
const int h = blockIdx.x % H;
|
||||
const int i = threadIdx.x;
|
||||
_u += h*_N_;
|
||||
_state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
|
||||
|
||||
__shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
|
||||
|
||||
float state[_N_];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < _N_; j++)
|
||||
state[j] = _state[j];
|
||||
|
||||
__syncthreads();
|
||||
u[i] = float(_u[i]);
|
||||
__syncthreads();
|
||||
|
||||
for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
|
||||
{
|
||||
__syncthreads();
|
||||
w[i] = _w[t];
|
||||
r[i] = float(_r[t]);
|
||||
k[i] = float(_k[t]);
|
||||
__syncthreads();
|
||||
|
||||
const float v = float(_v[t]);
|
||||
float y = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < _N_; j+=4)
|
||||
{
|
||||
const float4& r_ = (float4&)(r[j]);
|
||||
const float4& k_ = (float4&)(k[j]);
|
||||
const float4& w_ = (float4&)(w[j]);
|
||||
const float4& u_ = (float4&)(u[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 x;
|
||||
|
||||
x.x = k_.x * v;
|
||||
x.y = k_.y * v;
|
||||
x.z = k_.z * v;
|
||||
x.w = k_.w * v;
|
||||
|
||||
y += r_.x * (u_.x * x.x + s.x);
|
||||
y += r_.y * (u_.y * x.y + s.y);
|
||||
y += r_.z * (u_.z * x.z + s.z);
|
||||
y += r_.w * (u_.w * x.w + s.w);
|
||||
|
||||
s.x = s.x * w_.x + x.x;
|
||||
s.y = s.y * w_.y + x.y;
|
||||
s.z = s.z * w_.z + x.z;
|
||||
s.w = s.w * w_.w + x.w;
|
||||
}
|
||||
_y[t] = F(y);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < _N_; j++)
|
||||
_state[j] = state[j];
|
||||
}
|
||||
|
||||
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
|
||||
{
|
||||
assert(H*_N_ == C);
|
||||
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
|
||||
}
|
||||
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
|
||||
{
|
||||
assert(H*_N_ == C);
|
||||
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
|
||||
}
|
||||
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
|
||||
{
|
||||
assert(H*_N_ == C);
|
||||
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
|
||||
}
|
34
backend-python/rwkv_pip/cuda/rwkv6_op.cpp
vendored
Normal file
34
backend-python/rwkv_pip/cuda/rwkv6_op.cpp
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
#include <torch/extension.h>
|
||||
#include "ATen/ATen.h"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
typedef at::BFloat16 bf16;
|
||||
typedef at::Half fp16;
|
||||
typedef float fp32;
|
||||
|
||||
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
|
||||
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
|
||||
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
|
||||
|
||||
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
||||
cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
|
||||
}
|
||||
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
||||
cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
|
||||
}
|
||||
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
||||
cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
|
||||
m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
|
||||
m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
|
||||
}
|
||||
TORCH_LIBRARY(rwkv6, m) {
|
||||
m.def("forward_bf16", forward_bf16);
|
||||
m.def("forward_fp16", forward_fp16);
|
||||
m.def("forward_fp32", forward_fp32);
|
||||
}
|
818
backend-python/rwkv_pip/model.py
vendored
818
backend-python/rwkv_pip/model.py
vendored
File diff suppressed because it is too large
Load Diff
BIN
backend-python/rwkv_pip/rwkv5.pyd
vendored
BIN
backend-python/rwkv_pip/rwkv5.pyd
vendored
Binary file not shown.
BIN
backend-python/rwkv_pip/rwkv6.pyd
vendored
Normal file
BIN
backend-python/rwkv_pip/rwkv6.pyd
vendored
Normal file
Binary file not shown.
BIN
backend-python/rwkv_pip/wkv_cuda.pyd
vendored
BIN
backend-python/rwkv_pip/wkv_cuda.pyd
vendored
Binary file not shown.
Loading…
Reference in New Issue
Block a user