diff --git a/backend-python/requirements.txt b/backend-python/requirements.txt
index be995a6..5627eb9 100644
Binary files a/backend-python/requirements.txt and b/backend-python/requirements.txt differ
diff --git a/backend-python/rwkv_pip/cuda/att_one.cu b/backend-python/rwkv_pip/cuda/att_one.cu
deleted file mode 100644
index f22858d..0000000
--- a/backend-python/rwkv_pip/cuda/att_one.cu
+++ /dev/null
@@ -1,124 +0,0 @@
-#include "ATen/ATen.h"
-#include <cuda_fp16.h>
-#include <cuda_runtime.h>
-#include <torch/extension.h>
-
-#include "element_wise.h"
-#include "util.h"
-
-// Equivalent Python code:
-// ww = t_first + k
-// p = torch.maximum(pp, ww)
-// e1 = torch.exp(pp - p)
-// e2 = torch.exp(ww - p)
-// wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
-// ww = t_decay + pp
-// p = torch.maximum(ww, k)
-// e1 = torch.exp(ww - p)
-// e2 = torch.exp(k - p)
-// t1 = e1 * aa + e2 * v
-// t2 = e1 * bb + e2
-// r = r * wkv
-// return t1, t2, p, r
-struct WkvForwardOne {
-  const float *t_first;
-  const float *k;
-  const float *pp;
-  const float *aa;
-  const float *bb;
-  const float *t_decay;
-  const float *v;
-  /* out */ float *t1;
-  /* out */ float *t2;
-  /* out */ float *p;
-  /* in & out */ half *r;
-
-  __device__ void operator()(int i) const {
-    float ww = t_first[i] + k[i];
-    float pp_ = pp[i];
-    float p_ = (pp_ > ww) ? pp_ : ww;
-    float e1 = expf(pp_ - p_);
-    float e2 = expf(ww - p_);
-    float aa_ = aa[i];
-    float bb_ = bb[i];
-    float v_ = v[i];
-    r[i] = __hmul(r[i], __float2half(((e1 * aa_ + e2 * v_) / (e1 * bb_ + e2))));
-    ww = t_decay[i] + pp_;
-    float k_ = k[i];
-    p_ = (ww > k_) ? ww : k_;
-    e1 = expf(ww - p_);
-    e2 = expf(k_ - p_);
-    t1[i] = e1 * aa_ + e2 * v_;
-    t2[i] = e1 * bb_ + e2;
-    p[i] = p_;
-  }
-};
-
-/*
-   Equivalent Python code:
-   kx = xx * k_mix + sx * (1 - k_mix)
-   vx = xx * v_mix + sx * (1 - v_mix)
-   rx = xx * r_mix + sx * (1 - r_mix)
-*/
-
-struct Mix {
-  const half *xx;
-  const half *sx;
-  const half *k_mix;
-  const half *v_mix;
-  const half *r_mix;
-  /* out */ half *kx;
-  /* out */ half *vx;
-  /* out */ half *rx;
-
-  __device__ void operator()(int i) const {
-    half xx_ = xx[i];
-    half sx_ = sx[i];
-    half k_mix_ = k_mix[i];
-    half v_mix_ = v_mix[i];
-    half r_mix_ = r_mix[i];
-    kx[i] = __hadd(__hmul(xx_, k_mix_),
-                   __hmul(sx_, __hsub(__float2half(1), k_mix_)));
-    vx[i] = __hadd(__hmul(xx_, v_mix_),
-                   __hmul(sx_, __hsub(__float2half(1), v_mix_)));
-    rx[i] = __hadd(__hmul(xx_, r_mix_),
-                   __hmul(sx_, __hsub(__float2half(1), r_mix_)));
-  }
-};
-
-using torch::Tensor;
-
-void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
-
-Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
-               Tensor v_mix, Tensor r_mix, Tensor kw,
-               /* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, Tensor rw,
-               /* imm */ Tensor rx, Tensor ow, Tensor t_first,
-               /* imm */ Tensor k, Tensor pp, Tensor ww, Tensor aa, Tensor bb,
-               Tensor t_decay, /* imm */ Tensor v, /* in & out */ Tensor r,
-               /* out */ Tensor x_plus_out, /* out */ Tensor t1,
-               /* out */ Tensor t2, /* out */ Tensor p) {
-  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
-  element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
-                   data_ptr<half>(k_mix), data_ptr<half>(v_mix),
-                   data_ptr<half>(r_mix), data_ptr<half>(kx),
-                   data_ptr<half>(vx), data_ptr<half>(rx)},
-               x.numel());
-
-  gemm_fp16_cublas(kx, kw, k);
-  gemm_fp16_cublas(vx, vw, v);
-  gemm_fp16_cublas(rx, rw, r);
-  at::sigmoid_(r);
-
-  element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
-                             data_ptr<float>(pp), data_ptr<float>(aa),
-                             data_ptr<float>(bb), data_ptr<float>(t_decay),
-                             data_ptr<float>(v), data_ptr<float>(t1),
-                             data_ptr<float>(t2), data_ptr<float>(p),
-                             data_ptr<half>(r)},
-               x.numel());
-
-  gemm_fp16_cublas(r, ow, x_plus_out);
-  x_plus_out += x;
-  return xx;
-}
diff --git a/backend-python/rwkv_pip/cuda/att_seq.cu b/backend-python/rwkv_pip/cuda/att_seq.cu
deleted file mode 100644
index 4d506a3..0000000
--- a/backend-python/rwkv_pip/cuda/att_seq.cu
+++ /dev/null
@@ -1,179 +0,0 @@
-#include "ATen/ATen.h"
-#include <cuda_fp16.h>
-#include <cuda_runtime.h>
-#include <torch/extension.h>
-
-#include "util.h"
-#include "element_wise.h"
-
-using torch::Tensor;
-
-void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
-void gemm_fp16_cublas(const void *a, const void *b, void *c, int m,
-                      int n, int k, bool output_fp32);
-
-// based on `kernel_wkv_forward`, fusing more operations
-__global__ void kernel_wkv_forward_new(
-    const int B, const int T, const int C, const float *__restrict__ const _w,
-    const float *__restrict__ const _u, const float *__restrict__ const _k,
-    const float *__restrict__ const _v, const half *__restrict__ const r,
-    half *__restrict__ const _y, float *__restrict__ const _aa,
-    float *__restrict__ const _bb, float *__restrict__ const _pp) {
-  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
-  const int _b = idx / C;
-  const int _c = idx % C;
-  const int _offset = _b * T * C + _c;
-  const int _state_offset = _b * C + _c;
-
-  float u = _u[_c];
-  float w = _w[_c];
-  const float *__restrict__ const k = _k + _offset;
-  const float *__restrict__ const v = _v + _offset;
-  half *__restrict__ const y = _y + _offset;
-
-  float aa = _aa[_state_offset];
-  float bb = _bb[_state_offset];
-  float pp = _pp[_state_offset];
-  for (int i = 0; i < T; i++) {
-    const int ii = i * C;
-    const float kk = k[ii];
-    const float vv = v[ii];
-    float ww = u + kk;
-    float p = max(pp, ww);
-    float e1 = exp(pp - p);
-    float e2 = exp(ww - p);
-    y[ii] = __float2half((e1 * aa + e2 * vv) / (e1 * bb + e2));
-    ww = w + pp;
-    p = max(ww, kk);
-    e1 = exp(ww - p);
-    e2 = exp(kk - p);
-    aa = e1 * aa + e2 * vv;
-    bb = e1 * bb + e2;
-    pp = p;
-  }
-  _aa[_state_offset] = aa;
-  _bb[_state_offset] = bb;
-  _pp[_state_offset] = pp;
-}
-
-void cuda_wkv_forward_new(int B, int T, int C, float *w, float *u, float *k,
-                          float *v, half *r, half *y, float *aa, float *bb,
-                          float *pp) {
-  dim3 threadsPerBlock(min(C, 32));
-  assert(B * C % threadsPerBlock.x == 0);
-  dim3 numBlocks(B * C / threadsPerBlock.x);
-  kernel_wkv_forward_new<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, r,
-                                                         y, aa, bb, pp);
-}
-
-__global__ void _att_mix(const half *xx, const half *sx, const half *k_mix,
-                         const half *v_mix, const half *r_mix,
-                         const int outer_size, const int inner_size, half *kx,
-                         half *vx, half *rx) {
-  for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size;
-       idx2 += blockDim.x * gridDim.x) {
-    half k_mix_ = k_mix[idx2];
-    half v_mix_ = v_mix[idx2];
-    half r_mix_ = r_mix[idx2];
-    for (int row = 0; row < outer_size; ++row) {
-      int idx1 = row * inner_size + idx2;
-      half xx_ = xx[idx1];
-      half sx_ = sx[idx1];
-      kx[idx1] = __hadd(__hmul(xx_, k_mix_),
-                        __hmul(sx_, __hsub(__float2half(1), k_mix_)));
-      vx[idx1] = __hadd(__hmul(xx_, v_mix_),
-                        __hmul(sx_, __hsub(__float2half(1), v_mix_)));
-      rx[idx1] = __hadd(__hmul(xx_, r_mix_),
-                        __hmul(sx_, __hsub(__float2half(1), r_mix_)));
-    }
-  }
-}
-
-void att_mix(const half *xx, const half *sx, const half *k_mix,
-             const half *v_mix, const half *r_mix, const int outer_size,
-             const int inner_size, half *kx, half *vx, half *rx) {
-  // 256 is good enough on most GPUs
-  const int32_t BLOCK_SIZE = 256;
-  assert(inner_size % BLOCK_SIZE == 0);
-  _att_mix<<<inner_size / BLOCK_SIZE, BLOCK_SIZE>>>(
-      xx, sx, k_mix, v_mix, r_mix, outer_size, inner_size, kx, vx, rx);
-}
-
-struct InplaceSigmoid {
-  __device__ __forceinline__ half operator()(int i) const {
-    ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i]))));
-  }
-  half *ptr;
-};
-
-struct InplaceMul {
-  __device__ __forceinline__ half operator()(int i) const {
-    y[i] = __hmul(x[i], y[i]);
-  }
-  half *y;
-  half *x;
-};
-
-/*
-   Equivalent Python code:
-
-   xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-   sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
-   kx = xx * k_mix + sx * (1 - k_mix)
-   vx = xx * v_mix + sx * (1 - v_mix)
-   rx = xx * r_mix + sx * (1 - r_mix)
-
-   r = torch.sigmoid(gemm(rx, rw))
-   k = gemm(kx, kw, output_dtype=torch.float32)
-   v = gemm(vx, vw, output_dtype=torch.float32)
-
-   T = x.shape[0]
-   for t in range(T):
-       kk = k[t]
-       vv = v[t]
-       ww = t_first + kk
-       p = torch.maximum(pp, ww)
-       e1 = torch.exp(pp - p)
-       e2 = torch.exp(ww - p)
-       sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
-       ww = t_decay + pp
-       p = torch.maximum(ww, kk)
-       e1 = torch.exp(ww - p)
-       e2 = torch.exp(kk - p)
-       aa = e1 * aa + e2 * vv
-       bb = e1 * bb + e2
-       pp = p
-   out = gemm(r * sx, ow)
-   return x + out, xx[-1,:], aa, bb, pp
-*/
-Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
-               Tensor v_mix, Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
-               Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
-               Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out) {
-  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
-  sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0);
-  char* buf_ptr = (char*)buf.data_ptr();
-  half* kx = (half*)buf_ptr;
-  half* vx = kx + x.numel();
-  half* rx = vx + x.numel();
-  half* wkv_y = rx + x.numel();
-  att_mix(data_ptr<half>(xx), data_ptr<half>(sx), data_ptr<half>(k_mix),
-          data_ptr<half>(v_mix), data_ptr<half>(r_mix), xx.size(0), xx.size(1),
-          kx, vx, rx);
-  float* k = reinterpret_cast<float*>(wkv_y + x.numel());
-  float* v = k + x.size(0) * kw.size(1);
-  half* r = reinterpret_cast<half*>(v + x.size(0) * vw.size(1));
-
-  gemm_fp16_cublas(kx, kw.data_ptr(), k, x.size(0), kw.size(1), kw.size(0), true);
-  gemm_fp16_cublas(vx, vw.data_ptr(), v, x.size(0), vw.size(1), vw.size(0), true);
-  gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), rw.size(0), false);
-  element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1));
-  cuda_wkv_forward_new(1, x.size(0), x.size(1), data_ptr<float>(t_decay),
-                       data_ptr<float>(t_first), k, v, r,
-                       wkv_y, data_ptr<float>(aa),
-                       data_ptr<float>(bb), data_ptr<float>(pp));
-  element_wise(InplaceMul{wkv_y, r}, x.numel());
-  gemm_fp16_cublas(wkv_y, ow.data_ptr(), x_plus_out.data_ptr(), x.size(0), ow.size(1), ow.size(0), false);
-  x_plus_out += x;
-  return xx;
-}
diff --git a/backend-python/rwkv_pip/cuda/element_wise.h b/backend-python/rwkv_pip/cuda/element_wise.h
deleted file mode 100644
index eedc2f9..0000000
--- a/backend-python/rwkv_pip/cuda/element_wise.h
+++ /dev/null
@@ -1,21 +0,0 @@
-#include <cassert>
-#include <cstdint>
-#include <cuda_runtime.h>
-
-template <typename Func> __global__ void _element_wise(Func func, int n) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
-       i += blockDim.x * gridDim.x) {
-    func(i);
-  }
-}
-
-// NOTE: packed data type (e.g. float4) is a overkill for current sizes
-// (4096 in 7B model and 768 in 0.1B model),
-// and is not faster than the plain float version.
-template <typename Func>
-void element_wise(Func func, int n) {
-  // 256 is good enough on most GPUs
-  const int32_t BLOCK_SIZE = 256;
-  assert(n % BLOCK_SIZE == 0);
-  _element_wise<<<n / BLOCK_SIZE, BLOCK_SIZE>>>(func, n);
-}
diff --git a/backend-python/rwkv_pip/cuda/ffn.cu b/backend-python/rwkv_pip/cuda/ffn.cu
deleted file mode 100644
index c1c2c80..0000000
--- a/backend-python/rwkv_pip/cuda/ffn.cu
+++ /dev/null
@@ -1,165 +0,0 @@
-#include "ATen/ATen.h"
-#include <cuda_fp16.h>
-#include <cuda_runtime.h>
-#include <torch/extension.h>
-
-#include "element_wise.h"
-#include "util.h"
-
-using torch::Tensor;
-
-void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m,
-                      int ori_n, int ori_k, bool output_fp32);
-
-__global__ void _ffn_seq_mix(const half *xx, const half *sx, const half *k_mix,
-                             const half *r_mix, const int outer_size,
-                             const int inner_size, half *kx, half *rx) {
-  for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size;
-       idx2 += blockDim.x * gridDim.x) {
-    half k_mix_ = k_mix[idx2];
-    half r_mix_ = r_mix[idx2];
-    for (int row = 0; row < outer_size; ++row) {
-      int idx1 = row * inner_size + idx2;
-      half xx_ = xx[idx1];
-      half sx_ = sx[idx1];
-      kx[idx1] = __hadd(__hmul(xx_, k_mix_),
-                        __hmul(sx_, __hsub(__float2half(1), k_mix_)));
-      rx[idx1] = __hadd(__hmul(xx_, r_mix_),
-                        __hmul(sx_, __hsub(__float2half(1), r_mix_)));
-    }
-  }
-}
-
-void ffn_seq_mix(const half *xx, const half *sx, const half *k_mix,
-                 const half *r_mix, const int outer_size, const int inner_size,
-                 half *kx, half *rx) {
-  // 256 is good enough on most GPUs
-  const int32_t BLOCK_SIZE = 256;
-  assert(inner_size % BLOCK_SIZE == 0);
-  _ffn_seq_mix<<<inner_size / BLOCK_SIZE, BLOCK_SIZE>>>(
-      xx, sx, k_mix, r_mix, outer_size, inner_size, kx, rx);
-}
-
-struct InplaceSigmoid {
-  __device__ __forceinline__ void operator()(int i) const {
-    ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i]))));
-  }
-  half *ptr;
-};
-
-struct InplaceReLUAndSquare {
-  __device__ __forceinline__ void operator()(int i) const {
-    // __hmax is not defined in old cuda
-    if (__hgt(ptr[i], __float2half(0))) {
-      ptr[i] = __hmul(ptr[i], ptr[i]);
-    } else {
-      ptr[i] = __float2half(0);
-    }
-  }
-  half *ptr;
-};
-
-struct InplaceFma {
-  __device__ __forceinline__ void operator()(int i) const {
-    a[i] = __hfma(a[i], b[i], c[i]);
-  }
-  half *a;
-  const half *b;
-  const half *c;
-};
-
-/*
-   Equivalent Python code:
-
-   xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-   sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
-   kx = xx * k_mix + sx * (1 - k_mix)
-   rx = xx * r_mix + sx * (1 - r_mix)
-
-   r = torch.sigmoid(gemm(rx, rw))
-   vx = torch.square(torch.relu(gemm(kx, kw)))
-   out = r * gemm(vx, vw)
-   return x + out, xx[-1,:]
-*/
-Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
-               Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
-               /* imm */ Tensor buf,
-               /* out */ Tensor x_plus_out) {
-  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
-  sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0);
-  char *buf_ptr = (char *)buf.data_ptr();
-  half *kx = (half *)buf_ptr;
-  half *rx = kx + x.numel();
-  half *vx = rx + x.numel();
-  half *r = vx + x.size(0) * kw.size(1);
-  ffn_seq_mix(data_ptr<half>(xx), data_ptr<half>(sx), data_ptr<half>(k_mix),
-              data_ptr<half>(r_mix), xx.size(0), xx.size(1), kx, rx);
-
-  gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), x.size(1),
-                   false);
-  element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1));
-  gemm_fp16_cublas(kx, kw.data_ptr(), vx, x.size(0), kw.size(1), x.size(1),
-                   false);
-  element_wise(InplaceReLUAndSquare{vx}, x.size(0) * kw.size(1));
-  gemm_fp16_cublas(vx, vw.data_ptr(), x_plus_out.data_ptr(), x.size(0),
-                   vw.size(1), vw.size(0), false);
-  element_wise(InplaceFma{data_ptr<half>(x_plus_out), r, data_ptr<half>(x)},
-               x_plus_out.numel());
-  return xx;
-}
-
-struct FfnOneMix {
-  __device__ __forceinline__ void operator()(int idx) {
-    half k_mix_ = k_mix[idx];
-    half r_mix_ = r_mix[idx];
-    half xx_ = xx[idx];
-    half sx_ = sx[idx];
-    kx[idx] = __hadd(__hmul(xx_, k_mix_),
-                     __hmul(sx_, __hsub(__float2half(1), k_mix_)));
-    rx[idx] = __hadd(__hmul(xx_, r_mix_),
-                     __hmul(sx_, __hsub(__float2half(1), r_mix_)));
-  }
-  half *k_mix;
-  half *r_mix;
-  half *xx;
-  half *sx;
-  half *kx;
-  half *rx;
-};
-
-/*
-   Equivalent Python code:
-
-   xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-   kx = xx * k_mix + sx * (1 - k_mix)
-   rx = xx * r_mix + sx * (1 - r_mix)
-
-   r = torch.sigmoid(gemm(rx, rw))
-   vx = torch.square(torch.relu(gemm(kx, kw)))
-   out = r * gemm(vx, vw)
-   return x + out, xx
-*/
-Tensor ffn_one(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
-               Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
-               /* imm */ Tensor buf,
-               /* out */ Tensor x_plus_out) {
-  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
-  char *buf_ptr = (char *)buf.data_ptr();
-  half *kx = (half *)buf_ptr;
-  half *rx = kx + x.numel();
-  half *vx = rx + x.numel();
-  half *r = vx + x.size(0) * kw.size(1);
-  element_wise(FfnOneMix{data_ptr<half>(k_mix), data_ptr<half>(r_mix),
-                         data_ptr<half>(xx), data_ptr<half>(sx), kx, rx},
-               x.numel());
-  // vector * matrix, so m = 1
-  gemm_fp16_cublas(rx, rw.data_ptr(), r, 1, rw.size(1), rw.size(0), false);
-  element_wise(InplaceSigmoid{r}, rw.size(1));
-  gemm_fp16_cublas(kx, kw.data_ptr(), vx, 1, kw.size(1), kw.size(0), false);
-  element_wise(InplaceReLUAndSquare{vx}, kw.size(1));
-  gemm_fp16_cublas(vx, vw.data_ptr(), x_plus_out.data_ptr(), 1, vw.size(1),
-                   vw.size(0), false);
-  element_wise(InplaceFma{data_ptr<half>(x_plus_out), r, data_ptr<half>(x)},
-               x_plus_out.numel());
-  return xx;
-}
diff --git a/backend-python/rwkv_pip/cuda/util.h b/backend-python/rwkv_pip/cuda/util.h
deleted file mode 100644
index f00af22..0000000
--- a/backend-python/rwkv_pip/cuda/util.h
+++ /dev/null
@@ -1,7 +0,0 @@
-#include "ATen/ATen.h"
-#include <cuda_fp16.h>
-
-template <typename T> T *data_ptr(torch::Tensor x) { return x.data_ptr<T>(); }
-template <> inline half *data_ptr(torch::Tensor x) {
-  return reinterpret_cast<half *>(x.data_ptr());
-}
diff --git a/backend-python/rwkv_pip/model.py b/backend-python/rwkv_pip/model.py
index 81469b9..050bb98 100644
--- a/backend-python/rwkv_pip/model.py
+++ b/backend-python/rwkv_pip/model.py
@@ -220,7 +220,7 @@ class RWKV(MyModule):
         else:
             prxxx = lambda *args, **kwargs: None
 
-        STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
+        STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
         if not re.match(STRATEGY_REGEX, strategy):
            raise ValueError(
                "Invalid strategy. Please read https://pypi.org/project/rwkv/"
@@ -372,6 +372,10 @@ class RWKV(MyModule):
                         strategy[n].atype = s[i][1][0]
                         strategy[n].wtype = s[i][1][1]
                         strategy[n].stream = False
+                        if strategy[n].device == "dml":
+                            import torch_directml
+
+                            strategy[n].device = torch_directml.device()
                         if i == stream_i and n >= (plan[i] - stream_count):
                             strategy[n].stream = True
                         break
@@ -577,10 +581,7 @@ class RWKV(MyModule):
                 prxxx(f"Converted and saved. Now this will exit.")
                 exit(0)
 
-            if self.version == 5.2:
-                assert (
-                    os.environ["RWKV_CUDA_ON"] == "1"
-                ), "Please Enable Custom CUDA Kernel. Latest RWKV-5 requires os.environ['RWKV_CUDA_ON'] == '1' (will fix soon)"
+            if self.version == 5.2 and os.environ["RWKV_CUDA_ON"] == "1":
                 HEAD_SIZE = args.n_att // args.n_head
                 if LoadPreCompileLibrary("rwkv5") is True:
                     rwkv5 = torch.ops.rwkv5
@@ -1363,6 +1364,7 @@ class RWKV(MyModule):
 
     ########################################################################################################
 
+    @MyFunction
     def att_seq_v5_2(
         self,
         x,
@@ -1408,29 +1410,29 @@ class RWKV(MyModule):
         gx = xx * g_mix + sx * (1 - g_mix)
 
         H = t_decay.shape[0]
-        N = x.shape[-1] // H
+        S = x.shape[-1] // H
         T = x.shape[0]
 
-        r = gemm(rx, rw, output_dtype=torch.float32)
-        k = gemm(kx, kw, output_dtype=torch.float32)
-        v = gemm(vx, vw, output_dtype=torch.float32)
+        r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
+        k = (
+            gemm(kx, kw, output_dtype=torch.float32)
+            .view(T, H, S)
+            .transpose(0, 1)
+            .transpose(-2, -1)
+        )
+        v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
         g = F.silu(gemm(gx, gw))
 
-        out, s = self.RUN_RWKV_5(
-            1,
-            T,
-            self.args.n_att,
-            H,
-            s.transpose(-1, -2).contiguous(),
-            r,
-            k,
-            v,
-            w=t_decay,
-            u=t_first,
-        )
-        s = s.transpose(-1, -2)
+        out = torch.empty((T, H, S), dtype=r.dtype, device=r.device)
+        for t in range(T):
+            rt = r[:, t : t + 1, :]
+            kt = k[:, :, t : t + 1]
+            vt = v[:, t : t + 1, :]
+            at = gemm(kt, vt)
+            out[t] = (rt @ (t_first * at + s)).squeeze(1)
+            s = at + t_decay * s
 
-        out = out.reshape(T, H * N)
+        out = out.reshape(T, H * S)
         out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
         out = out.to(dtype=x.dtype) * g
         out = gemm(out, ow)
@@ -1543,6 +1545,81 @@ class RWKV(MyModule):
            out = self.mm8_seq(r * y, ow, omx, orx, omy, ory)
         return x + out, xx[-1, :], aa, bb, pp
 
+    # NOTE: decorate with @MyFunction causes JIT error
+    def cuda_att_seq_v5_2(
+        self,
+        x,
+        sx,
+        s,
+        ln_w,
+        ln_b,
+        lx_w,
+        lx_b,
+        k_mix,
+        v_mix,
+        r_mix,
+        g_mix,
+        t_decay,
+        t_first,
+        kw,
+        vw,
+        rw,
+        gw,
+        ow,
+        kmx,
+        krx,
+        kmy,
+        kry,
+        vmx,
+        vrx,
+        vmy,
+        vry,
+        rmx,
+        rrx,
+        rmy,
+        rry,
+        omx,
+        orx,
+        omy,
+        ory,
+    ):
+        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
+        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
+        kx = xx * k_mix + sx * (1 - k_mix)
+        vx = xx * v_mix + sx * (1 - v_mix)
+        rx = xx * r_mix + sx * (1 - r_mix)
+        gx = xx * g_mix + sx * (1 - g_mix)
+
+        H = t_decay.shape[0]
+        N = x.shape[-1] // H
+        T = x.shape[0]
+
+        r = gemm(rx, rw, output_dtype=torch.float32)
+        k = gemm(kx, kw, output_dtype=torch.float32)
+        v = gemm(vx, vw, output_dtype=torch.float32)
+        g = F.silu(gemm(gx, gw))
+
+        out, s = self.RUN_RWKV_5(
+            1,
+            T,
+            self.args.n_att,
+            H,
+            s.transpose(-1, -2).contiguous(),
+            r,
+            k,
+            v,
+            w=t_decay,
+            u=t_first,
+        )
+        s = s.transpose(-1, -2)
+
+        out = out.reshape(T, H * N)
+        out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
+        out = out.to(dtype=x.dtype) * g
+        out = gemm(out, ow)
+
+        return x + out, xx[-1, :], s
+
     ########################################################################################################
 
     def forward(self, tokens, state, full_output=False):
@@ -1622,7 +1699,10 @@ class RWKV(MyModule):
                 atype = dd.atype
                 wtype = dd.wtype
                 if seq_mode:
-                    if "cuda" in str(dev) and os.environ["RWKV_CUDA_ON"] == "1":
+                    cuda_applicable = os.environ[
+                        "RWKV_CUDA_ON"
+                    ] == "1" and "cuda" in str(dev)
+                    if cuda_applicable:
                         ATT = (
                             self.cuda_att_seq
                             if wtype != torch.uint8
@@ -1636,6 +1716,8 @@ class RWKV(MyModule):
                         ATT = self.att_seq_v5_1
                     elif self.version == 5.2:
                         ATT = self.att_seq_v5_2
+                        if cuda_applicable:
+                            ATT = self.cuda_att_seq_v5_2
                     FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8
                 else:
                     ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
diff --git a/frontend/src/_locales/ja/main.json b/frontend/src/_locales/ja/main.json
index 46e85bd..5f7793d 100644
--- a/frontend/src/_locales/ja/main.json
+++ b/frontend/src/_locales/ja/main.json
@@ -254,6 +254,5 @@
   "User Name": "ユーザー名",
   "Assistant Name": "アシスタント名",
   "Insert default system prompt at the beginning": "最初にデフォルトのシステムプロンプトを挿入",
-  "Please Enable Custom CUDA Kernel. Latest RWKV-5 requires os.environ['RWKV_CUDA_ON'] == '1' (will fix soon).": "カスタムCUDAカーネルを有効にしてください。最新のRWKV-5ではos.environ['RWKV_CUDA_ON'] == '1'が必要です(近日中に修正します)。",
   "Format Content": "内容フォーマットの規格化"
 }
\ No newline at end of file
diff --git a/frontend/src/_locales/zh-hans/main.json b/frontend/src/_locales/zh-hans/main.json
index 139c2b9..b54f233 100644
--- a/frontend/src/_locales/zh-hans/main.json
+++ b/frontend/src/_locales/zh-hans/main.json
@@ -254,6 +254,5 @@
   "User Name": "用户名称",
   "Assistant Name": "AI名称",
   "Insert default system prompt at the beginning": "在开头自动插入默认系统提示",
-  "Please Enable Custom CUDA Kernel. Latest RWKV-5 requires os.environ['RWKV_CUDA_ON'] == '1' (will fix soon).": "请启用自定义CUDA算子。最新的RWKV-5需要os.environ['RWKV_CUDA_ON'] == '1' (未来会修复)",
   "Format Content": "规范格式"
 }
\ No newline at end of file
diff --git a/frontend/src/components/RunButton.tsx b/frontend/src/components/RunButton.tsx
index c007ebc..4ade47f 100644
--- a/frontend/src/components/RunButton.tsx
+++ b/frontend/src/components/RunButton.tsx
@@ -212,7 +212,6 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
     'no NVIDIA driver': 'Found no NVIDIA driver, please install the latest driver.',
     'CUDA out of memory': 'VRAM is not enough, please reduce stored layers or use a lower precision in Configs page.',
     'Ninja is required to load C++ extensions': 'Failed to enable custom CUDA kernel, ninja is required to load C++ extensions. You may be using the CPU version of PyTorch, please reinstall PyTorch with CUDA. Or if you are using a custom Python interpreter, you must compile the CUDA kernel by yourself or disable Custom CUDA kernel acceleration.',
-    'Please Enable Custom CUDA Kernel': 'Please Enable Custom CUDA Kernel. Latest RWKV-5 requires os.environ[\'RWKV_CUDA_ON\'] == \'1\' (will fix soon).'
   };
   const matchedError = Object.entries(errorsMap).find(([key, _]) => error.includes(key));
   const message = matchedError ? t(matchedError[1]) : error;