diff --git a/backend-python/rwkv_pip/cuda/gemm_fp16_cublas.cpp b/backend-python/rwkv_pip/cuda/gemm_fp16_cublas.cpp index db48fcf..6823ce8 100644 --- a/backend-python/rwkv_pip/cuda/gemm_fp16_cublas.cpp +++ b/backend-python/rwkv_pip/cuda/gemm_fp16_cublas.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include #define CUBLAS_CHECK(condition) \ for (cublasStatus_t _cublas_check_status = (condition); \ @@ -18,26 +20,13 @@ "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \ " at " + std::to_string(__LINE__)); -cublasHandle_t get_cublas_handle() { - static cublasHandle_t cublas_handle = []() { - cublasHandle_t handle = nullptr; - CUBLAS_CHECK(cublasCreate(&handle)); -#if CUDA_VERSION < 11000 - CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); -#else - CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); -#endif // CUDA_VERSION < 11000 - return handle; - }(); - return cublas_handle; -} - /* NOTE: blas gemm is column-major by default, but we need row-major output. The data of row-major, transposed matrix is exactly the same as the column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T */ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const auto cuda_data_type = CUDA_R_16F; const auto cuda_c_data_type = c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F; @@ -55,7 +44,7 @@ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) { const int cublas_lda = m; const int cublas_ldb = k; const int cublas_ldc = m; - cublasHandle_t cublas_handle = get_cublas_handle(); + cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle(); #if CUDA_VERSION >= 11000 cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; diff --git a/backend-python/rwkv_pip/cuda/rwkv5_op.cpp b/backend-python/rwkv_pip/cuda/rwkv5_op.cpp index 5471bcf..e61e604 100644 --- a/backend-python/rwkv_pip/cuda/rwkv5_op.cpp +++ b/backend-python/rwkv_pip/cuda/rwkv5_op.cpp @@ -1,5 +1,6 @@ #include #include "ATen/ATen.h" +#include typedef at::BFloat16 bf16; typedef at::Half fp16; typedef float fp32; @@ -9,12 +10,15 @@ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 * void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y); void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); cuda_forward_bf16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); } void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); cuda_forward_fp16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); } void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); cuda_forward_fp32(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); } diff --git a/backend-python/rwkv_pip/model.py b/backend-python/rwkv_pip/model.py index 050bb98..4173c4f 100644 --- a/backend-python/rwkv_pip/model.py +++ b/backend-python/rwkv_pip/model.py @@ -171,10 +171,86 @@ if os.environ.get("RWKV_CUDA_ON") == "1": else: os.environ["RWKV_CUDA_ON"] = "0" -if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM: + +@MyStatic +def torch_mm8_seq(x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + +@MyStatic +def torch_mm8_one(x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + +if os.environ.get("RWKV_CUDA_ON") == "1": @MyStatic - def gemm(a, b, output_dtype: Optional[torch.dtype] = None): + def mm8_seq(x, w, mx, rx, my, ry): + if w.device.type == "cuda" and x.dtype == torch.float16: + B, N, M = x.shape[0], w.shape[0], w.shape[1] + return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry) + else: + return torch_mm8_seq(x, w, mx, rx, my, ry) + + @MyStatic + def mm8_one(x, w, mx, rx, my, ry): + if w.device.type == "cuda": + N, M = w.shape[0], w.shape[1] + return cuda_mm8_one(N, M, x, w, mx, rx, my, ry) + else: + return torch_mm8_one(x, w, mx, rx, my, ry) + +else: + + @MyStatic + def mm8_seq(x, w, mx, rx, my, ry): + return torch_mm8_seq(x, w, mx, rx, my, ry) + + @MyStatic + def mm8_one(x, w, mx, rx, my, ry): + return torch_mm8_one(x, w, mx, rx, my, ry) + + +def mm8( + x: torch.Tensor, + w: torch.Tensor, + mx: torch.Tensor, + rx: torch.Tensor, + my: torch.Tensor, + ry: torch.Tensor, +): + if len(x.shape) == 1: + return mm8_one(x, w, mx, rx, my, ry) + return mm8_seq(x, w, mx, rx, my, ry) + + +def matmul( + a, + b, + mx: Optional[torch.Tensor] = None, + rx: Optional[torch.Tensor] = None, + my: Optional[torch.Tensor] = None, + ry: Optional[torch.Tensor] = None, + output_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if output_dtype is None: + output_dtype = a.dtype + if b.dtype in [torch.float16, torch.bfloat16, torch.float32]: + assert a.dtype == b.dtype + return matmul_float(a, b, output_dtype=output_dtype) + elif b.dtype == torch.uint8: + assert mx is not None + assert rx is not None + assert my is not None + assert ry is not None + return mm8(a, b, mx, rx, my, ry).to(output_dtype) + else: + raise ValueError("Unsupported dtype") + + +if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM: + + def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None): if output_dtype is None: output_dtype = a.dtype if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda": @@ -203,9 +279,7 @@ if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM: else: - def gemm(a, b, output_dtype: Optional[torch.dtype] = None): - if output_dtype is None: - output_dtype = a.dtype + def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None): return (a @ b).to(output_dtype) @@ -644,42 +718,6 @@ class RWKV(MyModule): def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u): return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u) - @MyFunction - def torch_mm8_seq(self, x, w, mx, rx, my, ry): - return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) - - @MyFunction - def torch_mm8_one(self, x, w, mx, rx, my, ry): - return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) - - if os.environ.get("RWKV_CUDA_ON") == "1": - - @MyFunction - def mm8_seq(self, x, w, mx, rx, my, ry): - if w.device.type == "cuda" and x.dtype == torch.float16: - B, N, M = x.shape[0], w.shape[0], w.shape[1] - return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry) - else: - return self.torch_mm8_seq(x, w, mx, rx, my, ry) - - @MyFunction - def mm8_one(self, x, w, mx, rx, my, ry): - if w.device.type == "cuda": - N, M = w.shape[0], w.shape[1] - return cuda_mm8_one(N, M, x, w, mx, rx, my, ry) - else: - return self.torch_mm8_one(x, w, mx, rx, my, ry) - - else: - - @MyFunction - def mm8_seq(self, x, w, mx, rx, my, ry): - return self.torch_mm8_seq(x, w, mx, rx, my, ry) - - @MyFunction - def mm8_one(self, x, w, mx, rx, my, ry): - return self.torch_mm8_one(x, w, mx, rx, my, ry) - ######################################################################################################## @MyFunction @@ -711,43 +749,9 @@ class RWKV(MyModule): kx = xx * k_mix + sx * (1 - k_mix) rx = xx * r_mix + sx * (1 - r_mix) - r = torch.sigmoid(gemm(rx, rw)) - vx = torch.square(torch.relu(gemm(kx, kw))) - out = r * gemm(vx, vw) - return x + out, xx - - @MyFunction - def ffn_one_i8( - self, - x, - sx, - ln_w, - ln_b, - k_mix, - r_mix, - kw, - vw, - rw, - kmx, - krx, - kmy, - kry, - vmx, - vrx, - vmy, - vry, - rmx, - rrx, - rmy, - rry, - ): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - kx = xx * k_mix + sx * (1 - k_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) - vx = torch.square(torch.relu(self.mm8_one(kx, kw, kmx, krx, kmy, kry))) - out = r * (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)) + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry))) + out = r * matmul(vx, vw, vmx, vrx, vmy, vry) return x + out, xx ######################################################################################################## @@ -782,44 +786,9 @@ class RWKV(MyModule): kx = xx * k_mix + sx * (1 - k_mix) rx = xx * r_mix + sx * (1 - r_mix) - r = torch.sigmoid(gemm(rx, rw)) - vx = torch.square(torch.relu(gemm(kx, kw))) - out = r * gemm(vx, vw) - return x + out, xx[-1, :] - - @MyFunction - def ffn_seq_i8( - self, - x, - sx, - ln_w, - ln_b, - k_mix, - r_mix, - kw, - vw, - rw, - kmx, - krx, - kmy, - kry, - vmx, - vrx, - vmy, - vry, - rmx, - rrx, - rmy, - rry, - ): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) - kx = xx * k_mix + sx * (1 - k_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) - vx = torch.square(torch.relu(self.mm8_seq(kx, kw, kmx, krx, kmy, kry))) - out = r * (self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)) + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry))) + out = r * matmul(vx, vw, vmx, vrx, vmy, vry) return x + out, xx[-1, :] ######################################################################################################## @@ -865,9 +834,9 @@ class RWKV(MyModule): vx = xx * v_mix + sx * (1 - v_mix) rx = xx * r_mix + sx * (1 - r_mix) - r = torch.sigmoid(gemm(rx, rw)) - k = gemm(kx, kw, output_dtype=torch.float32) - v = gemm(vx, vw, output_dtype=torch.float32) + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) ww = t_first + k p = torch.maximum(pp, ww) @@ -879,65 +848,7 @@ class RWKV(MyModule): e1 = torch.exp(ww - p) e2 = torch.exp(k - p) - out = gemm(r * wkv, ow) - return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p - - @MyFunction - def att_one_i8( - self, - x, - sx, - aa, - bb, - pp, - ln_w, - ln_b, - k_mix, - v_mix, - r_mix, - t_decay, - t_first, - kw, - vw, - rw, - ow, - kmx, - krx, - kmy, - kry, - vmx, - vrx, - vmy, - vry, - rmx, - rrx, - rmy, - rry, - omx, - orx, - omy, - ory, - ): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - kx = xx * k_mix + sx * (1 - k_mix) - vx = xx * v_mix + sx * (1 - v_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) - k = (self.mm8_one(kx, kw, kmx, krx, kmy, kry)).float() - v = (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)).float() - - ww = t_first + k - p = torch.maximum(pp, ww) - e1 = torch.exp(pp - p) - e2 = torch.exp(ww - p) - wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) - ww = t_decay + pp - p = torch.maximum(ww, k) - e1 = torch.exp(ww - p) - e2 = torch.exp(k - p) - - out = self.mm8_one(r * wkv, ow, omx, orx, omy, ory) + out = matmul(r * wkv, ow, omx, orx, omy, ory) return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p ######################################################################################################## @@ -984,9 +895,9 @@ class RWKV(MyModule): vx = xx * v_mix + sx * (1 - v_mix) rx = xx * r_mix + sx * (1 - r_mix) - r = torch.sigmoid(gemm(rx, rw)) - k = gemm(kx, kw, output_dtype=torch.float32) - v = gemm(vx, vw, output_dtype=torch.float32) + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) T = x.shape[0] for t in range(T): @@ -1004,72 +915,7 @@ class RWKV(MyModule): aa = e1 * aa + e2 * vv bb = e1 * bb + e2 pp = p - out = gemm(r * sx, ow) - return x + out, xx[-1, :], aa, bb, pp - - @MyFunction - def att_seq_i8( - self, - x, - sx, - aa, - bb, - pp, - ln_w, - ln_b, - k_mix, - v_mix, - r_mix, - t_decay, - t_first, - kw, - vw, - rw, - ow, - kmx, - krx, - kmy, - kry, - vmx, - vrx, - vmy, - vry, - rmx, - rrx, - rmy, - rry, - omx, - orx, - omy, - ory, - ): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) - kx = xx * k_mix + sx * (1 - k_mix) - vx = xx * v_mix + sx * (1 - v_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) - k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry).float() - v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry).float() - - T = x.shape[0] - for t in range(T): - kk = k[t] - vv = v[t] - ww = t_first + kk - p = torch.maximum(pp, ww) - e1 = torch.exp(pp - p) - e2 = torch.exp(ww - p) - sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) - ww = t_decay + pp - p = torch.maximum(ww, kk) - e1 = torch.exp(ww - p) - e2 = torch.exp(kk - p) - aa = e1 * aa + e2 * vv - bb = e1 * bb + e2 - pp = p - out = self.mm8_seq(r * sx, ow, omx, orx, omy, ory) + out = matmul(r * sx, ow, omx, orx, omy, ory) return x + out, xx[-1, :], aa, bb, pp ######################################################################################################## @@ -1118,11 +964,11 @@ class RWKV(MyModule): H = t_decay.shape[0] S = x.shape[-1] // H - r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S) - k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1) - v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S) + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S) - a = gemm(k, v) + a = matmul(k, v) out = r @ (t_first * a + s) s = a + t_decay * s @@ -1131,7 +977,7 @@ class RWKV(MyModule): out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b ).squeeze(0) out = out.to(dtype=x.dtype) - out = gemm(out, ow) + out = matmul(out, ow, omx, orx, omy, ory) return x + out, xx, s @@ -1194,14 +1040,22 @@ class RWKV(MyModule): w = w[:, :-T].reshape(-1, T, 2 * T - 1) w = w[:, :, T - 1 :].reshape(H, T, T) - r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) + r = ( + matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + .view(T, H, S) + .transpose(0, 1) + ) k = ( - gemm(kx, kw, output_dtype=torch.float32) + matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) .view(T, H, S) .transpose(0, 1) .transpose(-2, -1) ) - v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) + v = ( + matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + .view(T, H, S) + .transpose(0, 1) + ) out = ((r @ k) * w) @ v + (r @ s) * wb s = ws * s + (k * wk) @ v @@ -1209,7 +1063,7 @@ class RWKV(MyModule): out = out.transpose(0, 1).contiguous().reshape(T, H * S) out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) out = out.to(dtype=x.dtype) - out = gemm(out, ow) + out = matmul(out, ow, omx, orx, omy, ory) return x + out, xx[-1, :], s @@ -1248,6 +1102,10 @@ class RWKV(MyModule): rrx, rmy, rry, + gmx, + grx, + gmy, + gry, omx, orx, omy, @@ -1262,12 +1120,12 @@ class RWKV(MyModule): H = t_decay.shape[0] S = x.shape[-1] // H - r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S) - k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1) - v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S) - g = F.silu(gemm(gx, gw)) + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) - a = gemm(k, v) + a = matmul(k, v) out = r @ (t_first * a + s) s = a + t_decay * s @@ -1276,7 +1134,7 @@ class RWKV(MyModule): out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b ).squeeze(0) out = out.to(dtype=x.dtype) * g - out = gemm(out, ow) + out = matmul(out, ow, omx, orx, omy, ory) return x + out, xx, s @@ -1313,6 +1171,10 @@ class RWKV(MyModule): rrx, rmy, rry, + gmx, + grx, + gmy, + gry, omx, orx, omy, @@ -1342,15 +1204,23 @@ class RWKV(MyModule): w = w[:, :-T].reshape(-1, T, 2 * T - 1) w = w[:, :, T - 1 :].reshape(H, T, T) - r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) + r = ( + matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + .view(T, H, S) + .transpose(0, 1) + ) k = ( - gemm(kx, kw, output_dtype=torch.float32) + matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) .view(T, H, S) .transpose(0, 1) .transpose(-2, -1) ) - v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) - g = F.silu(gemm(gx, gw)) + v = ( + matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + .view(T, H, S) + .transpose(0, 1) + ) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) out = ((r @ k) * w) @ v + (r @ s) * wb s = ws * s + (k * wk) @ v @@ -1358,7 +1228,7 @@ class RWKV(MyModule): out = out.transpose(0, 1).contiguous().reshape(T, H * S) out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) out = out.to(dtype=x.dtype) * g - out = gemm(out, ow) + out = matmul(out, ow, omx, orx, omy, ory) return x + out, xx[-1, :], s @@ -1397,6 +1267,10 @@ class RWKV(MyModule): rrx, rmy, rry, + gmx, + grx, + gmy, + gry, omx, orx, omy, @@ -1413,29 +1287,37 @@ class RWKV(MyModule): S = x.shape[-1] // H T = x.shape[0] - r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) + r = ( + matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + .view(T, H, S) + .transpose(0, 1) + ) k = ( - gemm(kx, kw, output_dtype=torch.float32) + matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) .view(T, H, S) .transpose(0, 1) .transpose(-2, -1) ) - v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) - g = F.silu(gemm(gx, gw)) + v = ( + matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + .view(T, H, S) + .transpose(0, 1) + ) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) out = torch.empty((T, H, S), dtype=r.dtype, device=r.device) for t in range(T): rt = r[:, t : t + 1, :] kt = k[:, :, t : t + 1] vt = v[:, t : t + 1, :] - at = gemm(kt, vt) + at = matmul(kt, vt) out[t] = (rt @ (t_first * at + s)).squeeze(1) s = at + t_decay * s out = out.reshape(T, H * S) out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) out = out.to(dtype=x.dtype) * g - out = gemm(out, ow) + out = matmul(out, ow, omx, orx, omy, ory) return x + out, xx[-1, :], s @@ -1486,63 +1368,12 @@ class RWKV(MyModule): vx = xx * v_mix + sx * (1 - v_mix) rx = xx * r_mix + sx * (1 - r_mix) - r = torch.sigmoid(gemm(rx, rw)) - k = gemm(kx, kw, output_dtype=torch.float32) - v = gemm(vx, vw, output_dtype=torch.float32) - y, aa, bb, pp = cuda_wkv(T, aa.shape[0], t_decay, t_first, k, v, aa, bb, pp) - - out = gemm(r * y.to(x.dtype), ow) - return x + out, xx[-1, :], aa, bb, pp - - @MyFunction - def cuda_att_seq_i8( - self, - x, - sx, - aa, - bb, - pp, - ln_w, - ln_b, - k_mix, - v_mix, - r_mix, - t_decay, - t_first, - kw, - vw, - rw, - ow, - kmx, - krx, - kmy, - kry, - vmx, - vrx, - vmy, - vry, - rmx, - rrx, - rmy, - rry, - omx, - orx, - omy, - ory, - ): - T, C = x.shape - xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) - sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) - kx = xx * k_mix + sx * (1 - k_mix) - vx = xx * v_mix + sx * (1 - v_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) - k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry) - v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry) + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) - out = self.mm8_seq(r * y, ow, omx, orx, omy, ory) + out = matmul(r * y.to(x.dtype), ow, omx, orx, omy, ory) return x + out, xx[-1, :], aa, bb, pp # NOTE: decorate with @MyFunction causes JIT error @@ -1578,6 +1409,10 @@ class RWKV(MyModule): rrx, rmy, rry, + gmx, + grx, + gmy, + gry, omx, orx, omy, @@ -1594,10 +1429,10 @@ class RWKV(MyModule): N = x.shape[-1] // H T = x.shape[0] - r = gemm(rx, rw, output_dtype=torch.float32) - k = gemm(kx, kw, output_dtype=torch.float32) - v = gemm(vx, vw, output_dtype=torch.float32) - g = F.silu(gemm(gx, gw)) + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) out, s = self.RUN_RWKV_5( 1, @@ -1616,7 +1451,7 @@ class RWKV(MyModule): out = out.reshape(T, H * N) out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) out = out.to(dtype=x.dtype) * g - out = gemm(out, ow) + out = matmul(out, ow, omx, orx, omy, ory) return x + out, xx[-1, :], s @@ -1703,13 +1538,9 @@ class RWKV(MyModule): "RWKV_CUDA_ON" ] == "1" and "cuda" in str(dev) if cuda_applicable: - ATT = ( - self.cuda_att_seq - if wtype != torch.uint8 - else self.cuda_att_seq_i8 - ) + ATT = self.cuda_att_seq else: - ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8 + ATT = self.att_seq if self.version == 5: ATT = self.att_seq_v5 elif self.version == 5.1: @@ -1718,16 +1549,16 @@ class RWKV(MyModule): ATT = self.att_seq_v5_2 if cuda_applicable: ATT = self.cuda_att_seq_v5_2 - FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8 + FFN = self.ffn_seq else: - ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8 + ATT = self.att_one if self.version == 5: ATT = self.att_one_v5 elif self.version == 5.1: ATT = self.att_one_v5_1 elif self.version == 5.2: ATT = self.att_one_v5_1 # same as v5.1 - FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8 + FFN = self.ffn_one x = x.to(dtype=atype, device=dev) @@ -1872,6 +1703,10 @@ class RWKV(MyModule): rrx, rmy, rry, + gmx, + grx, + gmy, + gry, omx, orx, omy, @@ -1944,7 +1779,7 @@ class RWKV(MyModule): x = x @ w["head.weight"] else: if seq_mode and full_output: - x = self.mm8_seq( + x = mm8_seq( x, w["head.weight"], w["head.weight_mx"], @@ -1953,7 +1788,7 @@ class RWKV(MyModule): w["head.weight_ry"], ) else: - x = self.mm8_one( + x = mm8_one( x, w["head.weight"], w["head.weight_mx"], diff --git a/backend-python/rwkv_pip/rwkv5.pyd b/backend-python/rwkv_pip/rwkv5.pyd index f1aaac8..24d7c1f 100644 Binary files a/backend-python/rwkv_pip/rwkv5.pyd and b/backend-python/rwkv_pip/rwkv5.pyd differ diff --git a/backend-python/rwkv_pip/utils.py b/backend-python/rwkv_pip/utils.py index 36e165d..b09f230 100644 --- a/backend-python/rwkv_pip/utils.py +++ b/backend-python/rwkv_pip/utils.py @@ -81,6 +81,7 @@ class PIPELINE: def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0): probs = F.softmax(logits.float(), dim=-1) top_k = int(top_k) + # 'privateuseone' is the type of custom devices like `torch_directml.device()` if probs.device.type in ["cpu", "privateuseone"]: probs = probs.cpu().numpy() sorted_ids = np.argsort(probs) diff --git a/backend-python/rwkv_pip/wkv_cuda.pyd b/backend-python/rwkv_pip/wkv_cuda.pyd index 2b8bc6f..8e668ed 100644 Binary files a/backend-python/rwkv_pip/wkv_cuda.pyd and b/backend-python/rwkv_pip/wkv_cuda.pyd differ