diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5d7d7e2..55485bf 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -109,6 +109,7 @@ jobs: go install github.com/wailsapp/wails/v2/cmd/wails@latest rm ./backend-python/rwkv_pip/wkv_cuda.pyd rm ./backend-python/rwkv_pip/rwkv5.pyd + rm ./backend-python/rwkv_pip/rwkv6.pyd rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd rm ./backend-python/get-pip.py make @@ -141,6 +142,7 @@ jobs: go install github.com/wailsapp/wails/v2/cmd/wails@latest rm ./backend-python/rwkv_pip/wkv_cuda.pyd rm ./backend-python/rwkv_pip/rwkv5.pyd + rm ./backend-python/rwkv_pip/rwkv6.pyd rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd rm ./backend-python/get-pip.py make diff --git a/.gitignore b/.gitignore index b6ebf31..246fbdf 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,7 @@ __pycache__ /cmd-helper.bat /install-py-dep.bat /backend-python/wkv_cuda -/backend-python/rwkv5 +/backend-python/rwkv* *.exe *.old .DS_Store diff --git a/backend-python/rwkv_pip/cuda/rwkv6.cu b/backend-python/rwkv_pip/cuda/rwkv6.cu new file mode 100644 index 0000000..91d9cd3 --- /dev/null +++ b/backend-python/rwkv_pip/cuda/rwkv6.cu @@ -0,0 +1,87 @@ +#include <stdio.h> +#include <assert.h> +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; +typedef at::Half fp16; +typedef float fp32; + +template <typename F> +__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
+ + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + + float state[_N_]; + #pragma unroll + for (int j = 0; j < _N_; j++) + state[j] = _state[j]; + + __syncthreads(); + u[i] = float(_u[i]); + __syncthreads(); + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + w[i] = _w[t]; + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } + #pragma unroll + for (int j = 0; j < _N_; j++) + _state[j] = state[j]; +} + +void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) +{ + assert(H*_N_ == C); + kernel_forward<bf16><<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y); +} +void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y) +{ + assert(H*_N_ == C); + kernel_forward<fp16><<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y); +} +void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y) +{ + assert(H*_N_ == C); + kernel_forward<fp32><<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y); +} diff --git a/backend-python/rwkv_pip/cuda/rwkv6_op.cpp b/backend-python/rwkv_pip/cuda/rwkv6_op.cpp new file mode 100644 index 0000000..3701ea0 --- /dev/null +++ b/backend-python/rwkv_pip/cuda/rwkv6_op.cpp @@ -0,0 +1,34 @@ +#include <torch/extension.h> +#include "ATen/ATen.h" +#include <c10/cuda/CUDAGuard.h> +typedef at::BFloat16 bf16; +typedef at::Half fp16; +typedef float fp32; + +void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); +void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y); +void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y); + +void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); + cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>()); +} +void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); + cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>()); +} +void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); + cuda_forward_fp32(B, T, C, H,
state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>()); } + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16"); + m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16"); + m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32"); +} +TORCH_LIBRARY(rwkv6, m) { + m.def("forward_bf16", forward_bf16); + m.def("forward_fp16", forward_fp16); + m.def("forward_fp32", forward_fp32); +} diff --git a/backend-python/rwkv_pip/model.py b/backend-python/rwkv_pip/model.py index 4173c4f..cdfec6c 100644 --- a/backend-python/rwkv_pip/model.py +++ b/backend-python/rwkv_pip/model.py @@ -5,6 +5,7 @@ from typing import Optional import types, gc, os, time, re import torch +import torch.nn as nn from torch.nn import functional as F torch.backends.cudnn.benchmark = True @@ -80,7 +81,7 @@ else: if os.environ.get("RWKV_CUDA_ON") == "1": DISABLE_CUBLAS_GEMM = False - from torch.utils.cpp_extension import load # L581 + from torch.utils.cpp_extension import load if LoadPreCompileLibrary("wkv_cuda") is False: try: @@ -374,6 +375,11 @@ class RWKV(MyModule): if len(w[x].shape) > 1: if w[x].shape[1] > 1: self.version = max(5.2, self.version) + if "time_maa" in x: + self.version = max(6, self.version) + if int(self.version) == 6 and "time_faaaa" in x: + args.n_head = w[x].shape[0] + prxxx(f"Model detected: v{self.version:.1f}") ####################### Compute strategy @@ -524,22 +530,24 @@ class RWKV(MyModule): ): w[x] = w[x].t() - if ".time_decay" in x: # need fp32 for this + if ".time_decay" in x and "_w" not in x: # need fp32 for this if self.version == 4: w[x] = -torch.exp(w[x].float()) elif int(self.version) == 5: w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1, 1, 1) if self.version == 5.2: w[x] = w[x].reshape(args.n_head, -1, 1) + elif self.version == 6.0: + w[x] = w[x].float().reshape(args.n_head, -1, 1) elif ".time_first" in x: # need fp32 for this if self.version == 4: w[x] = w[x].float() - elif int(self.version) == 5: + elif int(self.version) in [5, 6]: if REAL_TIME_FIRST: w[x] = w[x].float().reshape(-1, 1, 1) else: w[x] = torch.exp(w[x].float()).reshape(-1, 1, 1) - if self.version == 5.2: + if self.version in [5.2, 6.0]: w[x] = w[x].reshape(args.n_head, -1, 1) elif ".ln_x" in x: # need fp32 for group_norm w[x] = w[x].float() @@ -695,22 +703,85 @@ class RWKV(MyModule): assert u.is_contiguous() assert state.is_contiguous() - y = torch.empty( - (B, T, C), - device=w.device, - dtype=r.dtype, - memory_format=torch.contiguous_format, - ) - if r.dtype == torch.bfloat16: - rwkv5.forward_bf16(B, T, C, H, state, r, k, v, w, u, y) - elif r.dtype == torch.float16: - rwkv5.forward_fp16(B, T, C, H, state, r, k, v, w, u, y) - elif r.dtype == torch.float32: - rwkv5.forward_fp32(B, T, C, H, state, r, k, v, w, u, y) - return y, state + y = torch.empty( + (B, T, C), + device=w.device, + dtype=r.dtype, + memory_format=torch.contiguous_format, + ) + if r.dtype == torch.bfloat16: + rwkv5.forward_bf16(B, T, C, H, state, r, k, v, w, u, y) + elif r.dtype == torch.float16: + rwkv5.forward_fp16(B, T, C, H, state, r, k, v, w, u, y) + elif r.dtype == torch.float32: + rwkv5.forward_fp32(B, T, C, H, state, r, k, v, w, u, y) + return y, state self.RWKV_5 = RWKV_5 + if self.version == 6.0 and os.environ["RWKV_CUDA_ON"] == "1": + HEAD_SIZE = args.n_att // args.n_head + if LoadPreCompileLibrary("rwkv6") is True: + rwkv6 = torch.ops.rwkv6 + else: + rwkv6 = load( + name="rwkv6", + sources=[ + f"{current_path}/cuda/rwkv6_op.cpp", +
f"{current_path}/cuda/rwkv6.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3" if os.name != "nt" else "", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + f"-D_T_={4096}", + ], + ) + + class RWKV_6(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, H, state, r, k, v, w, u): + with torch.no_grad(): + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert state.dtype == torch.float32 + assert w.dtype == torch.float32 + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + eew = torch.exp(-torch.exp(w.float())).contiguous() + + y = torch.empty( + (B, T, C), + device=w.device, + dtype=r.dtype, + memory_format=torch.contiguous_format, + ) + if r.dtype == torch.bfloat16: + rwkv6.forward_bf16( + B, T, C, H, state, r, k, v, eew, u, y + ) + elif r.dtype == torch.float16: + rwkv6.forward_fp16( + B, T, C, H, state, r, k, v, eew, u, y + ) + elif r.dtype == torch.float32: + rwkv6.forward_fp32( + B, T, C, H, state, r, k, v, eew, u, y + ) + return y, state + + self.RWKV_6 = RWKV_6 + gc.collect() if "cuda" in args.strategy_string: torch.cuda.empty_cache() @@ -718,6 +789,9 @@ class RWKV(MyModule): def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u): return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u) + def RUN_RWKV_6(self, B, T, C, H, state, r, k, v, w, u): + return self.RWKV_6.apply(B, T, C, H, state, r, k, v, w, u) + ######################################################################################################## @MyFunction @@ -750,12 +824,10 @@ class RWKV(MyModule): rx = xx * r_mix + sx * (1 - r_mix) r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) - vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry))) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 out = r * matmul(vx, vw, vmx, vrx, vmy, vry) return x + out, xx - ######################################################################################################## - @MyFunction def ffn_seq( self, @@ -787,7 +859,78 @@ class RWKV(MyModule): rx = xx * r_mix + sx * (1 - r_mix) r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) - vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry))) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 + out = r * matmul(vx, vw, vmx, vrx, vmy, vry) + return x + out, xx[-1, :] + + @MyFunction + def ffn_one_v6( + self, + x, + sx, + ln_w, + ln_b, + k_maa, + r_maa, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = sx - xx + kx = xx + sx * k_maa + rx = xx + sx * r_maa + + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 + out = r * matmul(vx, vw, vmx, vrx, vmy, vry) + return x + out, xx + + @MyFunction + def ffn_seq_v6( + self, + x, + sx, + ln_w, + ln_b, + k_maa, + r_maa, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + sx = sx - xx + kx = xx + sx * k_maa + rx = xx + sx * r_maa + + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 out = r * matmul(vx, vw, vmx, vrx, vmy, vry) return x + out, xx[-1, :] @@ -851,8 +994,6 @@ class RWKV(MyModule): out = 
matmul(r * wkv, ow, omx, orx, omy, ory) return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p - ######################################################################################################## - @MyFunction def att_seq( self, @@ -962,11 +1103,11 @@ class RWKV(MyModule): rx = xx * r_mix + sx * (1 - r_mix) H = t_decay.shape[0] - S = x.shape[-1] // H + N = x.shape[-1] // H - r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S) - k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1) - v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S) + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N) a = matmul(k, v) out = r @ (t_first * a + s) @@ -974,7 +1115,7 @@ class RWKV(MyModule): out = out.flatten() out = F.group_norm( - out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b + out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5 ).squeeze(0) out = out.to(dtype=x.dtype) out = matmul(out, ow, omx, orx, omy, ory) @@ -1024,7 +1165,7 @@ class RWKV(MyModule): rx = xx * r_mix + sx * (1 - r_mix) H = t_decay.shape[0] - S = x.shape[-1] // H + N = x.shape[-1] // H T = x.shape[0] w = t_decay.reshape(-1, 1) @@ -1042,26 +1183,25 @@ class RWKV(MyModule): r = ( matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) k = ( matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) - .view(T, H, S) - .transpose(0, 1) - .transpose(-2, -1) + .view(T, H, N) + .permute(1, 2, 0) ) v = ( matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) out = ((r @ k) * w) @ v + (r @ s) * wb s = ws * s + (k * wk) @ v - out = out.transpose(0, 1).contiguous().reshape(T, H * S) - out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) + out = out.transpose(0, 1).contiguous().reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) out = out.to(dtype=x.dtype) out = matmul(out, ow, omx, orx, omy, ory) @@ -1118,11 +1258,11 @@ class RWKV(MyModule): gx = xx * g_mix + sx * (1 - g_mix) H = t_decay.shape[0] - S = x.shape[-1] // H + N = x.shape[-1] // H - r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S) - k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1) - v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S) + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N) g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) a = matmul(k, v) @@ -1131,7 +1271,7 @@ class RWKV(MyModule): out = out.flatten() out = F.group_norm( - out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b + out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5 ).squeeze(0) out = out.to(dtype=x.dtype) * g out = matmul(out, ow, omx, orx, omy, ory) @@ -1188,7 +1328,7 @@ class RWKV(MyModule): gx = xx * g_mix + sx * (1 - g_mix) H = t_decay.shape[0] - S = x.shape[-1] // H + N = x.shape[-1] // H T = x.shape[0] w = t_decay.reshape(-1, 1) @@ -1206,18 +1346,17 @@ class RWKV(MyModule): r = ( matmul(rx, rw, rmx, rrx, rmy, rry, 
output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) k = ( matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) - .view(T, H, S) - .transpose(0, 1) - .transpose(-2, -1) + .view(T, H, N) + .permute(1, 2, 0) ) v = ( matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) @@ -1225,8 +1364,8 @@ class RWKV(MyModule): out = ((r @ k) * w) @ v + (r @ s) * wb s = ws * s + (k * wk) @ v - out = out.transpose(0, 1).contiguous().reshape(T, H * S) - out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) + out = out.transpose(0, 1).contiguous().reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) out = out.to(dtype=x.dtype) * g out = matmul(out, ow, omx, orx, omy, ory) @@ -1284,28 +1423,27 @@ class RWKV(MyModule): gx = xx * g_mix + sx * (1 - g_mix) H = t_decay.shape[0] - S = x.shape[-1] // H + N = x.shape[-1] // H T = x.shape[0] r = ( matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) k = ( matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) - .view(T, H, S) - .transpose(0, 1) - .transpose(-2, -1) + .view(T, H, N) + .permute(1, 2, 0) ) v = ( matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) - out = torch.empty((T, H, S), dtype=r.dtype, device=r.device) + out = torch.empty((T, H, N), dtype=r.dtype, device=r.device) for t in range(T): rt = r[:, t : t + 1, :] kt = k[:, :, t : t + 1] @@ -1314,8 +1452,198 @@ class RWKV(MyModule): out[t] = (rt @ (t_first * at + s)).squeeze(1) s = at + t_decay * s - out = out.reshape(T, H * S) - out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) + out = out.reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xx[-1, :], s + + ######################################################################################################## + + @MyFunction + def att_one_v6_0( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + + sx = sx - xx + xxx = xx + sx * x_maa + xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1) + xxx = torch.bmm(xxx, tm_w2).view(5, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + wx = xx + sx * (w_maa + mw) + kx = xx + sx * (k_maa + mk) + vx = xx + sx * (v_maa + mv) + rx = xx + sx * (r_maa + mr) + gx = xx + sx * (g_maa + mg) + + H = t_decay.shape[0] + N = x.shape[-1] // H + + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + w = t_decay + (torch.tanh(wx @ td_w1) @ td_w2).float().view(H, N, 1) + w = torch.exp(-torch.exp(w.float())) + + a = matmul(k, v) + out = r @ (t_first * a + s) + s = a + w * s + + out = out.flatten() + out = F.group_norm( + 
out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5 + ).squeeze(0) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xx, s + + @MyFunction + def att_seq_v6_0( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) - xx + xxx = xx + sx * x_maa + xxx = torch.tanh(xxx @ tm_w1).view(T, 5, -1).transpose(0, 1) + xxx = torch.bmm(xxx, tm_w2).view(5, T, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + wx = xx + sx * (w_maa + mw) + kx = xx + sx * (k_maa + mk) + vx = xx + sx * (v_maa + mv) + rx = xx + sx * (r_maa + mr) + gx = xx + sx * (g_maa + mg) + + r = ( + matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + k = ( + matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + .view(T, H, N) + .permute(1, 2, 0) + ) + v = ( + matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + w = t_decay.view(1, H, N, 1) + (torch.tanh(wx @ td_w1) @ td_w2).float().view( + T, H, N, 1 + ) + w = torch.exp(-torch.exp(w.float())) + out = torch.empty((T, H, N), dtype=r.dtype, device=r.device) + for t in range(T): + rt = r[:, t : t + 1, :] + kt = k[:, :, t : t + 1] + vt = v[:, t : t + 1, :] + at = matmul(kt, vt) + out[t] = (rt @ (t_first * at + s)).squeeze(1) + s = at + w[t] * s + + out = out.reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) out = out.to(dtype=x.dtype) * g out = matmul(out, ow, omx, orx, omy, ory) @@ -1376,8 +1704,8 @@ class RWKV(MyModule): out = matmul(r * y.to(x.dtype), ow, omx, orx, omy, ory) return x + out, xx[-1, :], aa, bb, pp - # NOTE: decorate with @MyFunction causes JIT error - def cuda_att_seq_v5_2( + @MyFunction + def v5_2_before( self, x, sx, @@ -1425,35 +1753,303 @@ class RWKV(MyModule): rx = xx * r_mix + sx * (1 - r_mix) gx = xx * g_mix + sx * (1 - g_mix) + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + return r, k, v, g, xx[-1, :], s.transpose(-1, -2).contiguous() + + @MyFunction + def v5_2_after( + self, t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory + ): H = t_decay.shape[0] N = x.shape[-1] // H T = x.shape[0] + s = s.transpose(-1, -2) + out = out.reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xxx, s + + def cuda_att_seq_v5_2( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + r, k, v, g, xxx, ss = 
self.v5_2_before( + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ) + + out, s = self.RUN_RWKV_5( + 1, T, self.args.n_att, H, ss, r, k, v, w=t_decay, u=t_first + ) + + return self.v5_2_after( + t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory + ) + + @MyFunction + def v6_0_before( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) - xx + xxx = xx + sx * x_maa + xxx = torch.tanh(xxx @ tm_w1).view(T, 5, -1).transpose(0, 1) + xxx = torch.bmm(xxx, tm_w2).view(5, T, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + wx = xx + sx * (w_maa + mw) + kx = xx + sx * (k_maa + mk) + vx = xx + sx * (v_maa + mv) + rx = xx + sx * (r_maa + mr) + gx = xx + sx * (g_maa + mg) + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) - out, s = self.RUN_RWKV_5( - 1, - T, - self.args.n_att, - H, - s.transpose(-1, -2).contiguous(), - r, - k, - v, - w=t_decay, - u=t_first, + w = t_decay.view(1, H, N, 1) + ( + torch.tanh(wx @ td_w1) @ td_w2 + ).float().view(T, H, N, 1) + + return r, k, v, g, w, xx[-1, :], s.transpose(-1, -2).contiguous() + + def cuda_att_seq_v6_0( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + r, k, v, g, w, xxx, ss = self.v6_0_before( + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, ) - s = s.transpose(-1, -2) - out = out.reshape(T, H * N) - out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) - out = out.to(dtype=x.dtype) * g - out = matmul(out, ow, omx, orx, omy, ory) - - return x + out, xx[-1, :], s + out, s = self.RUN_RWKV_6( + 1, T, self.args.n_att, H, ss, r, k, v, w=w, u=t_first + ) + return self.v5_2_after( + t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory + ) ######################################################################################################## @@ -1498,7 +2094,7 @@ class RWKV(MyModule): state[i * 5 + 4] = torch.zeros( args.n_embd, dtype=atype, requires_grad=False, device=dev ).contiguous() - elif int(self.version) == 5: + elif int(self.version) in [5, 6]: state = [None] * args.n_layer * 3 for i in range(args.n_layer): # state: 
0=att_xx 1=att_kv 2=ffn_xx dd = self.strategy[i] @@ -1549,7 +2145,13 @@ class RWKV(MyModule): ATT = self.att_seq_v5_2 if cuda_applicable: ATT = self.cuda_att_seq_v5_2 + elif self.version == 6.0: + ATT = self.att_seq_v6_0 + if cuda_applicable: + ATT = self.cuda_att_seq_v6_0 FFN = self.ffn_seq + if self.version >= 6.0: + FFN = self.ffn_seq_v6 else: ATT = self.att_one if self.version == 5: @@ -1558,7 +2160,11 @@ class RWKV(MyModule): ATT = self.att_one_v5_1 elif self.version == 5.2: ATT = self.att_one_v5_1 # same as v5.1 + elif self.version == 6.0: + ATT = self.att_one_v6_0 FFN = self.ffn_one + if self.version >= 6.0: + FFN = self.ffn_one_v6 x = x.to(dtype=atype, device=dev) @@ -1587,7 +2193,7 @@ class RWKV(MyModule): orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x - if self.version == 5.1 or self.version == 5.2: + if self.version in [5.1, 5.2, 6.0]: gw = w[f"{att}gate.weight"] if dd.stream: gw = gw.to(device=dev, non_blocking=True) @@ -1671,7 +2277,7 @@ class RWKV(MyModule): omy, ory, ) - elif self.version == 5.1 or self.version == 5.2: + elif self.version in [5.1, 5.2]: x, state[i * 3 + 0], state[i * 3 + 1] = ATT( x, state[i * 3 + 0], @@ -1712,8 +2318,57 @@ class RWKV(MyModule): omy, ory, ) + elif self.version == 6.0: + x, state[i * 3 + 0], state[i * 3 + 1] = ATT( + x, + state[i * 3 + 0], + state[i * 3 + 1], + w[f"{bbb}ln1.weight"], + w[f"{bbb}ln1.bias"], + w[f"{att}ln_x.weight"], + w[f"{att}ln_x.bias"], + w[f"{att}time_maa_x"], + w[f"{att}time_maa_w"], + w[f"{att}time_maa_k"], + w[f"{att}time_maa_v"], + w[f"{att}time_maa_r"], + w[f"{att}time_maa_g"], + w[f"{att}time_maa_w1"], + w[f"{att}time_maa_w2"], + w[f"{att}time_decay_w1"], + w[f"{att}time_decay_w2"], + w[f"{att}time_decay"], + w[f"{att}time_first"], + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ) if dd.stream: del kw, vw, rw, ow + if self.version in [5.1, 5.2, 6.0]: + del gw kw = w[f"{ffn}key.weight"] vw = w[f"{ffn}value.weight"] @@ -1736,31 +2391,56 @@ class RWKV(MyModule): rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x if self.version == 4: offset = i * 5 + 4 - elif int(self.version) == 5: + elif int(self.version) in [5, 6]: offset = i * 3 + 2 - x, state[offset] = FFN( - x, - state[offset], - w[f"{bbb}ln2.weight"], - w[f"{bbb}ln2.bias"], - w[f"{ffn}time_mix_k"], - w[f"{ffn}time_mix_r"], - kw, - vw, - rw, - kmx, - krx, - kmy, - kry, - vmx, - vrx, - vmy, - vry, - rmx, - rrx, - rmy, - rry, - ) + if self.version < 6.0: + x, state[offset] = FFN( + x, + state[offset], + w[f"{bbb}ln2.weight"], + w[f"{bbb}ln2.bias"], + w[f"{ffn}time_mix_k"], + w[f"{ffn}time_mix_r"], + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ) + else: + x, state[offset] = FFN( + x, + state[offset], + w[f"{bbb}ln2.weight"], + w[f"{bbb}ln2.bias"], + w[f"{ffn}time_maa_k"], + w[f"{ffn}time_maa_r"], + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ) if dd.stream: del kw, vw, rw diff --git a/backend-python/rwkv_pip/rwkv5.pyd b/backend-python/rwkv_pip/rwkv5.pyd index 24d7c1f..5b59dde 100644 Binary files a/backend-python/rwkv_pip/rwkv5.pyd and b/backend-python/rwkv_pip/rwkv5.pyd differ diff --git a/backend-python/rwkv_pip/rwkv6.pyd b/backend-python/rwkv_pip/rwkv6.pyd new file 
mode 100644 index 0000000..480a03e Binary files /dev/null and b/backend-python/rwkv_pip/rwkv6.pyd differ diff --git a/backend-python/rwkv_pip/wkv_cuda.pyd b/backend-python/rwkv_pip/wkv_cuda.pyd index 8e668ed..9d7a281 100644 Binary files a/backend-python/rwkv_pip/wkv_cuda.pyd and b/backend-python/rwkv_pip/wkv_cuda.pyd differ
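Reviewer note (not part of the patch): for reference, below is a minimal pure-PyTorch sketch of the per-head state update that the new rwkv6.cu kernel and the att_one_v6_0 / att_seq_v6_0 paths compute. The function name and shapes are illustrative only; it assumes float32 inputs, a single head, batch size 1, and that w has already been mapped through exp(-exp(.)), the way eew is computed in RWKV_6.forward.

import torch

def rwkv6_head_recurrence(r, k, v, w, u, s):
    # r, k, v, w: (T, N) per-head projections; w is the decay after exp(-exp(.))
    # u: (N,) time_first bonus; s: (N, N) running key-value state
    T, N = r.shape
    y = torch.empty(T, N, dtype=torch.float32, device=r.device)
    for t in range(T):
        a = k[t].view(N, 1) @ v[t].view(1, N)                        # outer product k_t v_t^T
        y[t] = (r[t].view(1, N) @ (u.view(N, 1) * a + s)).flatten()  # bonus applied to current token only
        s = a + w[t].view(N, 1) * s                                  # decay old state, add new contribution
    return y, s

# smoke test with random tensors (hypothetical sizes)
T, N = 8, 64
y, s = rwkv6_head_recurrence(
    torch.randn(T, N), torch.randn(T, N), torch.randn(T, N),
    torch.exp(-torch.exp(torch.randn(T, N))), torch.randn(N), torch.zeros(N, N),
)

The CUDA kernel performs the same recurrence per (batch, head) block with one thread per channel, keeping the (N, N) state in registers and reading w per token, which is why the Python side only precomputes eew and leaves the loop over T to the kernel.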