diff --git a/backend-python/rwkv_pip/beta/cuda/att_one.cu b/backend-python/rwkv_pip/beta/cuda/att_one.cu index f22858d..743fc12 100644 --- a/backend-python/rwkv_pip/beta/cuda/att_one.cu +++ b/backend-python/rwkv_pip/beta/cuda/att_one.cu @@ -88,7 +88,7 @@ struct Mix { using torch::Tensor; -void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c); +void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c); Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix, Tensor v_mix, Tensor r_mix, Tensor kw, @@ -105,9 +105,9 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix, data_ptr(vx), data_ptr(rx)}, x.numel()); - gemm_fp16_cublas(kx, kw, k); - gemm_fp16_cublas(vx, vw, v); - gemm_fp16_cublas(rx, rw, r); + gemm_fp16_cublas_tensor(kx, kw, k); + gemm_fp16_cublas_tensor(vx, vw, v); + gemm_fp16_cublas_tensor(rx, rw, r); at::sigmoid_(r); element_wise(WkvForwardOne{data_ptr(t_first), data_ptr(k), @@ -118,7 +118,7 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix, data_ptr(r)}, x.numel()); - gemm_fp16_cublas(r, ow, x_plus_out); + gemm_fp16_cublas_tensor(r, ow, x_plus_out); x_plus_out += x; return xx; } diff --git a/backend-python/rwkv_pip/beta/cuda/att_one_v5.cu b/backend-python/rwkv_pip/beta/cuda/att_one_v5.cu new file mode 100644 index 0000000..98f811e --- /dev/null +++ b/backend-python/rwkv_pip/beta/cuda/att_one_v5.cu @@ -0,0 +1,109 @@ +#include "ATen/ATen.h" +#include +#include +#include + +#include "element_wise.h" +#include "util.h" + +// Equivalent Python code: +// s1 = t_first * a + s +// s2 = a + t_decay * s +struct Fused1 { + const float *t_first; + const float *t_decay; + const float *a; + const float *s; + const int32_t inner_size; + /* out */ float *s1; + /* out */ float *s2; + + __device__ void operator()(int i) const { + const int j = i / inner_size; + s1[i] = t_first[j] * a[i] + s[i]; + s2[i] = a[i] + t_decay[j] * s[i]; + } +}; + +/* + Equivalent Python code: + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) +*/ + +struct Mix { + const half *xx; + const half *sx; + const half *k_mix; + const half *v_mix; + const half *r_mix; + /* out */ half *kx; + /* out */ half *vx; + /* out */ half *rx; + + __device__ void operator()(int i) const { + half xx_ = xx[i]; + half sx_ = sx[i]; + half k_mix_ = k_mix[i]; + half v_mix_ = v_mix[i]; + half r_mix_ = r_mix[i]; + kx[i] = __hadd(__hmul(xx_, k_mix_), + __hmul(sx_, __hsub(__float2half(1), k_mix_))); + vx[i] = __hadd(__hmul(xx_, v_mix_), + __hmul(sx_, __hsub(__float2half(1), v_mix_))); + rx[i] = __hadd(__hmul(xx_, r_mix_), + __hmul(sx_, __hsub(__float2half(1), r_mix_))); + } +}; + +using torch::Tensor; + +void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c); + +Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b, + Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix, + Tensor r_mix, Tensor kw, + /* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, + Tensor rw, + /* imm */ Tensor rx, Tensor ow, Tensor t_first, + /* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v, + /* imm */ Tensor r, /* imm */ Tensor s1, + /* out */ Tensor x_plus_out, /* out */ Tensor s2) { + Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b); + element_wise(Mix{data_ptr(xx), data_ptr(sx), + data_ptr(k_mix), data_ptr(v_mix), + data_ptr(r_mix), data_ptr(kx), + data_ptr(vx), data_ptr(rx)}, + x.numel()); + + int H = t_decay.size(0); + int S = x.size(-1) / H; + gemm_fp16_cublas_tensor(rx, rw, r); + r = at::reshape(r, {H, 1, S}); + gemm_fp16_cublas_tensor(kx, kw, k); + k = at::reshape(k, {H, S, 1}); + gemm_fp16_cublas_tensor(vx, vw, v); + v = at::reshape(v, {H, 1, S}); + + { + Tensor a = at::matmul(k, v); + + // s1 = t_first * a + s + // s2 = a + t_decay * s + element_wise(Fused1{data_ptr(t_first), data_ptr(t_decay), + data_ptr(a), data_ptr(s), + static_cast(a.size(1) * a.size(2)), + data_ptr(s1), data_ptr(s2)}, + a.numel()); + } + + Tensor out = at::matmul(r, s1); + out = at::flatten(out); + out = at::squeeze(at::group_norm(at::unsqueeze(out, 0), H, lx_w, lx_b), 0); + out = at::_cast_Half(out); + + gemm_fp16_cublas_tensor(out, ow, x_plus_out); + x_plus_out += x; + return xx; +} diff --git a/backend-python/rwkv_pip/beta/cuda/att_seq.cu b/backend-python/rwkv_pip/beta/cuda/att_seq.cu index 4d506a3..c8db033 100644 --- a/backend-python/rwkv_pip/beta/cuda/att_seq.cu +++ b/backend-python/rwkv_pip/beta/cuda/att_seq.cu @@ -8,7 +8,6 @@ using torch::Tensor; -void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c); void gemm_fp16_cublas(const void *a, const void *b, void *c, int m, int n, int k, bool output_fp32); diff --git a/backend-python/rwkv_pip/beta/cuda/gemm_fp16_cublas.cpp b/backend-python/rwkv_pip/beta/cuda/gemm_fp16_cublas.cpp index 6ec136d..e1162ad 100644 --- a/backend-python/rwkv_pip/beta/cuda/gemm_fp16_cublas.cpp +++ b/backend-python/rwkv_pip/beta/cuda/gemm_fp16_cublas.cpp @@ -70,11 +70,59 @@ void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m, cuda_c_data_type, cublas_ldc, compute_type, algo)); } -void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) { - // comptiable with rwkv one mode, 1-D tensor * 2-D tensor - const int m = a.dense_dim() == 1 ? 1 : a.size(0); - const int n = b.size(1); - const int k = b.size(0); - gemm_fp16_cublas(a.data_ptr(), b.data_ptr(), c.data_ptr(), m, n, k, - c.dtype() == torch::kFloat32); +/* + NOTE: blas gemm is column-major by default, but we need row-major output. + The data of row-major, transposed matrix is exactly the same as the + column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T + */ +void gemm_fp16_cublas_tensor(torch::Tensor a, torch::Tensor b, torch::Tensor c) { + if (a.sizes().size() == 1) { + assert(b.sizes().size() == 2); + a = at::unsqueeze(a, 0); + } + const auto cuda_data_type = CUDA_R_16F; + const auto cuda_c_data_type = + c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F; + const auto compute_type = CUDA_R_32F; + const float sp_alpha = 1.f; + // swap a and b, and use CUBLAS_OP_N. see the notes above + std::swap(a, b); + const cublasOperation_t cublas_trans_a = CUBLAS_OP_N; + const cublasOperation_t cublas_trans_b = CUBLAS_OP_N; + // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap, + // negative axis is used because of the existence of batch matmul. + const int m = a.size(-1); + const int k = a.size(-2); + const int n = b.size(-2); + const int cublas_lda = m; + const int cublas_ldb = k; + const int cublas_ldc = m; + cublasHandle_t cublas_handle = get_cublas_handle(); + +#if CUDA_VERSION >= 11000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; +#else + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP; +#endif + const float sp_beta = 0.f; + if (a.sizes().size() == 2 && b.sizes().size() == 2) { + CUBLAS_CHECK(cublasGemmEx( + cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha, + a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type, + cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, + compute_type, algo)); + } else { + // batch matmul + assert(a.sizes().size() == 3 && b.sizes().size() == 3); + + const long long int cublas_stride_a = m * k; + const long long int cublas_stride_b = k * n; + const long long int cublas_stride_c = m * n; + CUBLAS_CHECK(cublasGemmStridedBatchedEx( + cublas_handle, cublas_trans_a, cublas_trans_b, m, + n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda, + cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b, + &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c, + a.size(0), compute_type, algo)); + } } diff --git a/backend-python/rwkv_pip/beta/cuda/wrapper.cpp b/backend-python/rwkv_pip/beta/cuda/wrapper.cpp index ee99bfc..5121499 100644 --- a/backend-python/rwkv_pip/beta/cuda/wrapper.cpp +++ b/backend-python/rwkv_pip/beta/cuda/wrapper.cpp @@ -118,7 +118,9 @@ void mm8_one(int64_t N, int64_t M, using torch::Tensor; -void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c); +#ifndef DISABLE_CUBLAS_GEMM +void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c); +#endif Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix, Tensor v_mix, Tensor r_mix, Tensor kw, @@ -134,6 +136,16 @@ Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix, Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb, Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out); +Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b, + Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix, + Tensor r_mix, Tensor kw, + /* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, + Tensor rw, + /* imm */ Tensor rx, Tensor ow, Tensor t_first, + /* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v, + /* imm */ Tensor r, /* imm */ Tensor s1, + /* out */ Tensor x_plus_out, /* out */ Tensor s2); + Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix, Tensor r_mix, Tensor kw, Tensor vw, Tensor rw, /* imm */ Tensor buf, @@ -148,8 +160,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("wkv_forward", &wkv_forward, "wkv forward"); m.def("mm8_seq", &mm8_seq, "mm8 seq"); m.def("mm8_one", &mm8_one, "mm8 one"); - m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas"); + m.def("gemm_fp16_cublas", &gemm_fp16_cublas_tensor, "gemv fp16 cublas"); m.def("att_one", &att_one, "att one"); + m.def("att_one_v5", &att_one_v5, "att one v5"); m.def("att_seq", &att_seq, "att seq"); m.def("ffn_seq", &ffn_seq, "ffn seq"); m.def("ffn_one", &ffn_one, "ffn one"); @@ -159,8 +172,9 @@ TORCH_LIBRARY(rwkv, m) { m.def("wkv_forward", wkv_forward); m.def("mm8_seq", mm8_seq); m.def("mm8_one", mm8_one); - m.def("gemm_fp16_cublas", gemm_fp16_cublas); + m.def("gemm_fp16_cublas", gemm_fp16_cublas_tensor); m.def("att_one", att_one); + m.def("att_one_v5", &att_one_v5); m.def("att_seq", att_seq); m.def("ffn_seq", ffn_seq); m.def("ffn_one", ffn_one); diff --git a/backend-python/rwkv_pip/beta/model.py b/backend-python/rwkv_pip/beta/model.py index 808e456..3cc0061 100644 --- a/backend-python/rwkv_pip/beta/model.py +++ b/backend-python/rwkv_pip/beta/model.py @@ -3,7 +3,7 @@ ######################################################################################################## from typing import Optional -import types, gc, os, time, re +import types, gc, os, time, re, platform import torch from torch.nn import functional as F @@ -91,6 +91,7 @@ if os.environ.get("RWKV_CUDA_ON") == "1": f"{current_path}/cuda/att_one.cu", f"{current_path}/cuda/att_seq.cu", f"{current_path}/cuda/ffn.cu", + f"{current_path}/cuda/att_one_v5.cu", ], verbose=True, extra_cuda_cflags=[ @@ -149,26 +150,40 @@ if os.environ.get("RWKV_CUDA_ON") == "1": torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y) return y.to(dtype=x.dtype) +else: + os.environ["RWKV_CUDA_ON"] = "0" + +if os.environ.get("RWKV_CUDA_ON") == "1": + @MyStatic def gemm(a, b, output_dtype: Optional[torch.dtype] = None): if output_dtype is None: output_dtype = a.dtype if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda": - assert len(b.shape) == 2 if len(a.shape) == 1: + assert len(b.shape) == 2 c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device) a = a.unsqueeze(0) else: - c = torch.empty( - (a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device - ) + assert len(a.shape) == len(b.shape) + assert len(a.shape) == 2 or len(a.shape) == 3 + # torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit + if len(a.shape) == 2: + c = torch.empty( + (a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device + ) + else: + c = torch.empty( + (a.shape[0], a.shape[1], b.shape[-1]), + dtype=output_dtype, + device=a.device, + ) torch.ops.rwkv.gemm_fp16_cublas(a, b, c) return c else: return (a @ b).to(output_dtype) else: - os.environ["RWKV_CUDA_ON"] = "0" def gemm(a, b, output_dtype: Optional[torch.dtype] = None): if output_dtype is None: @@ -217,7 +232,7 @@ class RWKV(MyModule): ) # load model to CPU first # it is supported to load a pure meta-tensor state dict (e.g. for quick testing) for k, v in self.w.items(): - if v.is_meta: + if isinstance(v, torch.Tensor) and v.is_meta: # torch.zeros_like(v, device='cpu') doesn't produce an all-zero tensor # if v is a meta tensor self.w[k] = torch.zeros(v.shape, dtype=v.dtype, device="cpu") @@ -247,9 +262,14 @@ class RWKV(MyModule): args.n_embd = w["emb.weight"].shape[1] args.n_layer = 0 keys = list(w.keys()) + self.version = 4 for x in keys: layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0 args.n_layer = max(args.n_layer, layer_id + 1) + if "ln_x" in x: + self.version = 5 + if self.version == 5 and "att.time_decay" in x: + args.n_head = w[x].shape[0] ####################### Compute strategy @@ -352,6 +372,20 @@ class RWKV(MyModule): del w["blocks.0.ln0.bias"] print_need_newline = False + + REAL_TIME_FIRST = False + for x in list(w.keys()): + if ".time_faaaa" in x: + REAL_TIME_FIRST = True + if REAL_TIME_FIRST: + w = { + k.replace(".time_faaaa", ".time_first") + if ".time_faaaa" in k + else k: v + for k, v in w.items() + } + self.w = w + keys = list(w.keys()) for x in keys: w[x].requires_grad = False @@ -382,8 +416,19 @@ class RWKV(MyModule): w[x] = w[x].t() if ".time_decay" in x: # need fp32 for this - w[x] = -torch.exp(w[x].float()) + if self.version == 4: + w[x] = -torch.exp(w[x].float()) + elif self.version == 5: + w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1, 1, 1) elif ".time_first" in x: # need fp32 for this + if self.version == 4: + w[x] = w[x].float() + elif self.version == 5: + if REAL_TIME_FIRST: + w[x] = w[x].float().reshape(-1, 1, 1) + else: + w[x] = torch.exp(w[x].float()).reshape(-1, 1, 1) + elif ".ln_x" in x: # need fp32 for group_norm w[x] = w[x].float() else: if (len(w[x].shape) == 2) and ("emb" not in x): @@ -931,6 +976,147 @@ class RWKV(MyModule): ######################################################################################################## + @MyFunction + def att_one_v5( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + H = t_decay.shape[0] + S = x.shape[-1] // H + + r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S) + k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1) + v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S) + + a = gemm(k, v) + out = r @ (t_first * a + s) + s = a + t_decay * s + + out = out.flatten() + out = F.group_norm( + out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b + ).squeeze(0) + out = out.to(dtype=x.dtype) + out = gemm(out, ow) + + return x + out, xx, s + + @MyFunction + def att_seq_v5( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + H = t_decay.shape[0] + S = x.shape[-1] // H + T = x.shape[0] + + w = t_decay.reshape(-1, 1) + u = t_first.reshape(-1, 1) + ws = w.pow(T).reshape(H, 1, 1) + ind = torch.arange(T - 1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1) + w = w.repeat(1, T).pow(ind) + wk = w.reshape(H, 1, T) + wb = wk.transpose(-2, -1).flip(1) + w = torch.cat([w[:, 1:], u], dim=1) + w = F.pad(w, (0, T)) + w = torch.tile(w, [T]) + w = w[:, :-T].reshape(-1, T, 2 * T - 1) + w = w[:, :, T - 1 :].reshape(H, T, T) + + r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) + k = ( + gemm(kx, kw, output_dtype=torch.float32) + .view(T, H, S) + .transpose(0, 1) + .transpose(-2, -1) + ) + v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) + + out = ((r @ k) * w) @ v + (r @ s) * wb + s = ws * s + (k * wk) @ v + + out = out.transpose(0, 1).contiguous().reshape(T, H * S) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) + out = out.to(dtype=x.dtype) + out = gemm(out, ow) + + return x + out, xx[-1, :], s + + ######################################################################################################## + if os.environ["RWKV_CUDA_ON"] == "1": @MyFunction @@ -1140,7 +1326,7 @@ class RWKV(MyModule): xx = torch.ops.rwkv.ffn_seq( x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, buf, x_plus_out ) - return x_plus_out, xx[-1:] + return x_plus_out, xx[-1, :] @MyFunction def cuda_att_one_fp16( @@ -1220,6 +1406,86 @@ class RWKV(MyModule): ) return x_plus_out_t, xx, t1_t, t2_t, p_t + @MyFunction + def cuda_att_one_v5_fp16( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + kx = torch.empty_like(x) + vx = torch.empty_like(x) + rx = torch.empty_like(x) + + H = t_decay.shape[0] + S = x.shape[-1] // H + + r = torch.empty((H * S,), dtype=torch.float32, device=x.device) + k = torch.empty((H * S,), dtype=torch.float32, device=x.device) + v = torch.empty((H * S,), dtype=torch.float32, device=x.device) + s1 = torch.empty((H, S, S), dtype=torch.float32, device=x.device) + s2 = torch.empty((H, S, S), dtype=torch.float32, device=x.device) + x_plus_out = torch.empty_like(x) + + xx = torch.ops.rwkv.att_one_v5( + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + kw, + kx, + vw, + vx, + rw, + rx, + ow, + t_first, + k, + t_decay, + v, + r, + s1, + x_plus_out, + s2, + ) + + return x_plus_out, xx, s2 + @MyFunction def cuda_ffn_one_fp16( self, @@ -1265,34 +1531,63 @@ class RWKV(MyModule): args = self.args if state == None: - state = [None] * args.n_layer * 5 - for i in range( - args.n_layer - ): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx - dd = self.strategy[i] - dev = dd.device - atype = dd.atype - state[i * 5 + 0] = torch.zeros( - args.n_embd, dtype=atype, requires_grad=False, device=dev - ).contiguous() - state[i * 5 + 1] = torch.zeros( - args.n_embd, dtype=torch.float, requires_grad=False, device=dev - ).contiguous() - state[i * 5 + 2] = torch.zeros( - args.n_embd, dtype=torch.float, requires_grad=False, device=dev - ).contiguous() - state[i * 5 + 3] = ( - torch.zeros( + if self.version == 4: + state = [None] * args.n_layer * 5 + for i in range( + args.n_layer + ): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i * 5 + 0] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() + state[i * 5 + 1] = torch.zeros( args.n_embd, dtype=torch.float, requires_grad=False, device=dev, ).contiguous() - - 1e30 - ) - state[i * 5 + 4] = torch.zeros( - args.n_embd, dtype=atype, requires_grad=False, device=dev - ).contiguous() + state[i * 5 + 2] = torch.zeros( + args.n_embd, + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() + state[i * 5 + 3] = ( + torch.zeros( + args.n_embd, + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() + - 1e30 + ) + state[i * 5 + 4] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() + elif self.version == 5: + state = [None] * args.n_layer * 3 + for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i * 3 + 0] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() + state[i * 3 + 1] = torch.zeros( + ( + args.n_head, + args.n_embd // args.n_head, + args.n_embd // args.n_head, + ), + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() + state[i * 3 + 2] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() seq_mode = len(tokens) > 1 @@ -1317,9 +1612,13 @@ class RWKV(MyModule): ATT = self.cuda_att_seq_i8 else: ATT = self.cuda_att_seq_naive + if self.version == 5: + ATT = self.att_seq_v5 else: ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8 FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8 + if self.version == 5: + ATT = self.att_one_v5 if ( "cuda" in str(dev) and os.environ["RWKV_CUDA_ON"] == "1" @@ -1327,6 +1626,8 @@ class RWKV(MyModule): ): ATT = self.cuda_att_one_fp16 FFN = self.cuda_ffn_one_fp16 + if self.version == 5: + ATT = self.cuda_att_one_v5_fp16 x = x.to(dtype=atype, device=dev) @@ -1355,46 +1656,82 @@ class RWKV(MyModule): orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x - ( - x, - state[i * 5 + 0], - state[i * 5 + 1], - state[i * 5 + 2], - state[i * 5 + 3], - ) = ATT( - x, - state[i * 5 + 0], - state[i * 5 + 1], - state[i * 5 + 2], - state[i * 5 + 3], - w[f"{bbb}ln1.weight"], - w[f"{bbb}ln1.bias"], - w[f"{att}time_mix_k"], - w[f"{att}time_mix_v"], - w[f"{att}time_mix_r"], - w[f"{att}time_decay"], - w[f"{att}time_first"], - kw, - vw, - rw, - ow, - kmx, - krx, - kmy, - kry, - vmx, - vrx, - vmy, - vry, - rmx, - rrx, - rmy, - rry, - omx, - orx, - omy, - ory, - ) + if self.version == 4: + ( + x, + state[i * 5 + 0], + state[i * 5 + 1], + state[i * 5 + 2], + state[i * 5 + 3], + ) = ATT( + x, + state[i * 5 + 0], + state[i * 5 + 1], + state[i * 5 + 2], + state[i * 5 + 3], + w[f"{bbb}ln1.weight"], + w[f"{bbb}ln1.bias"], + w[f"{att}time_mix_k"], + w[f"{att}time_mix_v"], + w[f"{att}time_mix_r"], + w[f"{att}time_decay"], + w[f"{att}time_first"], + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ) + elif self.version == 5: + x, state[i * 3 + 0], state[i * 3 + 1] = ATT( + x, + state[i * 3 + 0], + state[i * 3 + 1], + w[f"{bbb}ln1.weight"], + w[f"{bbb}ln1.bias"], + w[f"{att}ln_x.weight"], + w[f"{att}ln_x.bias"], + w[f"{att}time_mix_k"], + w[f"{att}time_mix_v"], + w[f"{att}time_mix_r"], + w[f"{att}time_decay"], + w[f"{att}time_first"], + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ) if dd.stream: del kw, vw, rw, ow @@ -1417,9 +1754,13 @@ class RWKV(MyModule): rrx = w[f"{ffn}receptance.weight_rx"] if wtype == torch.uint8 else x rmy = w[f"{ffn}receptance.weight_my"] if wtype == torch.uint8 else x rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x - x, state[i * 5 + 4] = FFN( + if self.version == 4: + offset = i * 5 + 4 + elif self.version == 5: + offset = i * 3 + 2 + x, state[offset] = FFN( x, - state[i * 5 + 4], + state[offset], w[f"{bbb}ln2.weight"], w[f"{bbb}ln2.bias"], w[f"{ffn}time_mix_k"],