upgrade to rwkv 0.8.20

This commit is contained in:
josc146 2023-11-03 23:27:14 +08:00
parent 35e92d2aef
commit 1f81a1e5a8
6 changed files with 188 additions and 359 deletions

View File

@ -3,6 +3,8 @@
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#define CUBLAS_CHECK(condition) \ #define CUBLAS_CHECK(condition) \
for (cublasStatus_t _cublas_check_status = (condition); \ for (cublasStatus_t _cublas_check_status = (condition); \
@ -18,26 +20,13 @@
"CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \ "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
" at " + std::to_string(__LINE__)); " at " + std::to_string(__LINE__));
cublasHandle_t get_cublas_handle() {
static cublasHandle_t cublas_handle = []() {
cublasHandle_t handle = nullptr;
CUBLAS_CHECK(cublasCreate(&handle));
#if CUDA_VERSION < 11000
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#else
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
return handle;
}();
return cublas_handle;
}
/* /*
NOTE: blas gemm is column-major by default, but we need row-major output. NOTE: blas gemm is column-major by default, but we need row-major output.
The data of row-major, transposed matrix is exactly the same as the The data of row-major, transposed matrix is exactly the same as the
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
*/ */
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) { void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
const auto cuda_data_type = CUDA_R_16F; const auto cuda_data_type = CUDA_R_16F;
const auto cuda_c_data_type = const auto cuda_c_data_type =
c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F; c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
@ -55,7 +44,7 @@ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
const int cublas_lda = m; const int cublas_lda = m;
const int cublas_ldb = k; const int cublas_ldb = k;
const int cublas_ldc = m; const int cublas_ldc = m;
cublasHandle_t cublas_handle = get_cublas_handle(); cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
#if CUDA_VERSION >= 11000 #if CUDA_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;

View File

@ -1,5 +1,6 @@
#include <torch/extension.h> #include <torch/extension.h>
#include "ATen/ATen.h" #include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16; typedef at::BFloat16 bf16;
typedef at::Half fp16; typedef at::Half fp16;
typedef float fp32; typedef float fp32;
@ -9,12 +10,15 @@ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y); void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>()); cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
} }
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>()); cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
} }
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>()); cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
} }

View File

@ -171,10 +171,86 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
else: else:
os.environ["RWKV_CUDA_ON"] = "0" os.environ["RWKV_CUDA_ON"] = "0"
if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
@MyStatic
def torch_mm8_seq(x, w, mx, rx, my, ry):
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
@MyStatic
def torch_mm8_one(x, w, mx, rx, my, ry):
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
if os.environ.get("RWKV_CUDA_ON") == "1":
@MyStatic @MyStatic
def gemm(a, b, output_dtype: Optional[torch.dtype] = None): def mm8_seq(x, w, mx, rx, my, ry):
if w.device.type == "cuda" and x.dtype == torch.float16:
B, N, M = x.shape[0], w.shape[0], w.shape[1]
return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
else:
return torch_mm8_seq(x, w, mx, rx, my, ry)
@MyStatic
def mm8_one(x, w, mx, rx, my, ry):
if w.device.type == "cuda":
N, M = w.shape[0], w.shape[1]
return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
else:
return torch_mm8_one(x, w, mx, rx, my, ry)
else:
@MyStatic
def mm8_seq(x, w, mx, rx, my, ry):
return torch_mm8_seq(x, w, mx, rx, my, ry)
@MyStatic
def mm8_one(x, w, mx, rx, my, ry):
return torch_mm8_one(x, w, mx, rx, my, ry)
def mm8(
x: torch.Tensor,
w: torch.Tensor,
mx: torch.Tensor,
rx: torch.Tensor,
my: torch.Tensor,
ry: torch.Tensor,
):
if len(x.shape) == 1:
return mm8_one(x, w, mx, rx, my, ry)
return mm8_seq(x, w, mx, rx, my, ry)
def matmul(
a,
b,
mx: Optional[torch.Tensor] = None,
rx: Optional[torch.Tensor] = None,
my: Optional[torch.Tensor] = None,
ry: Optional[torch.Tensor] = None,
output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if output_dtype is None:
output_dtype = a.dtype
if b.dtype in [torch.float16, torch.bfloat16, torch.float32]:
assert a.dtype == b.dtype
return matmul_float(a, b, output_dtype=output_dtype)
elif b.dtype == torch.uint8:
assert mx is not None
assert rx is not None
assert my is not None
assert ry is not None
return mm8(a, b, mx, rx, my, ry).to(output_dtype)
else:
raise ValueError("Unsupported dtype")
if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None):
if output_dtype is None: if output_dtype is None:
output_dtype = a.dtype output_dtype = a.dtype
if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda": if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda":
@ -203,9 +279,7 @@ if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
else: else:
def gemm(a, b, output_dtype: Optional[torch.dtype] = None): def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None):
if output_dtype is None:
output_dtype = a.dtype
return (a @ b).to(output_dtype) return (a @ b).to(output_dtype)
@ -644,42 +718,6 @@ class RWKV(MyModule):
def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u): def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u):
return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u) return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u)
@MyFunction
def torch_mm8_seq(self, x, w, mx, rx, my, ry):
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
@MyFunction
def torch_mm8_one(self, x, w, mx, rx, my, ry):
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
if os.environ.get("RWKV_CUDA_ON") == "1":
@MyFunction
def mm8_seq(self, x, w, mx, rx, my, ry):
if w.device.type == "cuda" and x.dtype == torch.float16:
B, N, M = x.shape[0], w.shape[0], w.shape[1]
return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
else:
return self.torch_mm8_seq(x, w, mx, rx, my, ry)
@MyFunction
def mm8_one(self, x, w, mx, rx, my, ry):
if w.device.type == "cuda":
N, M = w.shape[0], w.shape[1]
return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
else:
return self.torch_mm8_one(x, w, mx, rx, my, ry)
else:
@MyFunction
def mm8_seq(self, x, w, mx, rx, my, ry):
return self.torch_mm8_seq(x, w, mx, rx, my, ry)
@MyFunction
def mm8_one(self, x, w, mx, rx, my, ry):
return self.torch_mm8_one(x, w, mx, rx, my, ry)
######################################################################################################## ########################################################################################################
@MyFunction @MyFunction
@ -711,43 +749,9 @@ class RWKV(MyModule):
kx = xx * k_mix + sx * (1 - k_mix) kx = xx * k_mix + sx * (1 - k_mix)
rx = xx * r_mix + sx * (1 - r_mix) rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw)) r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
vx = torch.square(torch.relu(gemm(kx, kw))) vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)))
out = r * gemm(vx, vw) out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
return x + out, xx
@MyFunction
def ffn_one_i8(
self,
x,
sx,
ln_w,
ln_b,
k_mix,
r_mix,
kw,
vw,
rw,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
kx = xx * k_mix + sx * (1 - k_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry))
vx = torch.square(torch.relu(self.mm8_one(kx, kw, kmx, krx, kmy, kry)))
out = r * (self.mm8_one(vx, vw, vmx, vrx, vmy, vry))
return x + out, xx return x + out, xx
######################################################################################################## ########################################################################################################
@ -782,44 +786,9 @@ class RWKV(MyModule):
kx = xx * k_mix + sx * (1 - k_mix) kx = xx * k_mix + sx * (1 - k_mix)
rx = xx * r_mix + sx * (1 - r_mix) rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw)) r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
vx = torch.square(torch.relu(gemm(kx, kw))) vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)))
out = r * gemm(vx, vw) out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
return x + out, xx[-1, :]
@MyFunction
def ffn_seq_i8(
self,
x,
sx,
ln_w,
ln_b,
k_mix,
r_mix,
kw,
vw,
rw,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
kx = xx * k_mix + sx * (1 - k_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
vx = torch.square(torch.relu(self.mm8_seq(kx, kw, kmx, krx, kmy, kry)))
out = r * (self.mm8_seq(vx, vw, vmx, vrx, vmy, vry))
return x + out, xx[-1, :] return x + out, xx[-1, :]
######################################################################################################## ########################################################################################################
@ -865,9 +834,9 @@ class RWKV(MyModule):
vx = xx * v_mix + sx * (1 - v_mix) vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix) rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw)) r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
k = gemm(kx, kw, output_dtype=torch.float32) k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32) v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
ww = t_first + k ww = t_first + k
p = torch.maximum(pp, ww) p = torch.maximum(pp, ww)
@ -879,65 +848,7 @@ class RWKV(MyModule):
e1 = torch.exp(ww - p) e1 = torch.exp(ww - p)
e2 = torch.exp(k - p) e2 = torch.exp(k - p)
out = gemm(r * wkv, ow) out = matmul(r * wkv, ow, omx, orx, omy, ory)
return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p
@MyFunction
def att_one_i8(
self,
x,
sx,
aa,
bb,
pp,
ln_w,
ln_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry))
k = (self.mm8_one(kx, kw, kmx, krx, kmy, kry)).float()
v = (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)).float()
ww = t_first + k
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
ww = t_decay + pp
p = torch.maximum(ww, k)
e1 = torch.exp(ww - p)
e2 = torch.exp(k - p)
out = self.mm8_one(r * wkv, ow, omx, orx, omy, ory)
return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p
######################################################################################################## ########################################################################################################
@ -984,9 +895,9 @@ class RWKV(MyModule):
vx = xx * v_mix + sx * (1 - v_mix) vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix) rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw)) r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
k = gemm(kx, kw, output_dtype=torch.float32) k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32) v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
T = x.shape[0] T = x.shape[0]
for t in range(T): for t in range(T):
@ -1004,72 +915,7 @@ class RWKV(MyModule):
aa = e1 * aa + e2 * vv aa = e1 * aa + e2 * vv
bb = e1 * bb + e2 bb = e1 * bb + e2
pp = p pp = p
out = gemm(r * sx, ow) out = matmul(r * sx, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], aa, bb, pp
@MyFunction
def att_seq_i8(
self,
x,
sx,
aa,
bb,
pp,
ln_w,
ln_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry).float()
v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry).float()
T = x.shape[0]
for t in range(T):
kk = k[t]
vv = v[t]
ww = t_first + kk
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
ww = t_decay + pp
p = torch.maximum(ww, kk)
e1 = torch.exp(ww - p)
e2 = torch.exp(kk - p)
aa = e1 * aa + e2 * vv
bb = e1 * bb + e2
pp = p
out = self.mm8_seq(r * sx, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], aa, bb, pp return x + out, xx[-1, :], aa, bb, pp
######################################################################################################## ########################################################################################################
@ -1118,11 +964,11 @@ class RWKV(MyModule):
H = t_decay.shape[0] H = t_decay.shape[0]
S = x.shape[-1] // H S = x.shape[-1] // H
r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S) r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S)
k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1) k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1)
v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S) v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S)
a = gemm(k, v) a = matmul(k, v)
out = r @ (t_first * a + s) out = r @ (t_first * a + s)
s = a + t_decay * s s = a + t_decay * s
@ -1131,7 +977,7 @@ class RWKV(MyModule):
out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
).squeeze(0) ).squeeze(0)
out = out.to(dtype=x.dtype) out = out.to(dtype=x.dtype)
out = gemm(out, ow) out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx, s return x + out, xx, s
@ -1194,14 +1040,22 @@ class RWKV(MyModule):
w = w[:, :-T].reshape(-1, T, 2 * T - 1) w = w[:, :-T].reshape(-1, T, 2 * T - 1)
w = w[:, :, T - 1 :].reshape(H, T, T) w = w[:, :, T - 1 :].reshape(H, T, T)
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) r = (
matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
k = ( k = (
gemm(kx, kw, output_dtype=torch.float32) matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
.view(T, H, S) .view(T, H, S)
.transpose(0, 1) .transpose(0, 1)
.transpose(-2, -1) .transpose(-2, -1)
) )
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) v = (
matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
out = ((r @ k) * w) @ v + (r @ s) * wb out = ((r @ k) * w) @ v + (r @ s) * wb
s = ws * s + (k * wk) @ v s = ws * s + (k * wk) @ v
@ -1209,7 +1063,7 @@ class RWKV(MyModule):
out = out.transpose(0, 1).contiguous().reshape(T, H * S) out = out.transpose(0, 1).contiguous().reshape(T, H * S)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype) out = out.to(dtype=x.dtype)
out = gemm(out, ow) out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], s return x + out, xx[-1, :], s
@ -1248,6 +1102,10 @@ class RWKV(MyModule):
rrx, rrx,
rmy, rmy,
rry, rry,
gmx,
grx,
gmy,
gry,
omx, omx,
orx, orx,
omy, omy,
@ -1262,12 +1120,12 @@ class RWKV(MyModule):
H = t_decay.shape[0] H = t_decay.shape[0]
S = x.shape[-1] // H S = x.shape[-1] // H
r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S) r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S)
k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1) k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1)
v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S) v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S)
g = F.silu(gemm(gx, gw)) g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
a = gemm(k, v) a = matmul(k, v)
out = r @ (t_first * a + s) out = r @ (t_first * a + s)
s = a + t_decay * s s = a + t_decay * s
@ -1276,7 +1134,7 @@ class RWKV(MyModule):
out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
).squeeze(0) ).squeeze(0)
out = out.to(dtype=x.dtype) * g out = out.to(dtype=x.dtype) * g
out = gemm(out, ow) out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx, s return x + out, xx, s
@ -1313,6 +1171,10 @@ class RWKV(MyModule):
rrx, rrx,
rmy, rmy,
rry, rry,
gmx,
grx,
gmy,
gry,
omx, omx,
orx, orx,
omy, omy,
@ -1342,15 +1204,23 @@ class RWKV(MyModule):
w = w[:, :-T].reshape(-1, T, 2 * T - 1) w = w[:, :-T].reshape(-1, T, 2 * T - 1)
w = w[:, :, T - 1 :].reshape(H, T, T) w = w[:, :, T - 1 :].reshape(H, T, T)
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) r = (
matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
k = ( k = (
gemm(kx, kw, output_dtype=torch.float32) matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
.view(T, H, S) .view(T, H, S)
.transpose(0, 1) .transpose(0, 1)
.transpose(-2, -1) .transpose(-2, -1)
) )
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) v = (
g = F.silu(gemm(gx, gw)) matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
out = ((r @ k) * w) @ v + (r @ s) * wb out = ((r @ k) * w) @ v + (r @ s) * wb
s = ws * s + (k * wk) @ v s = ws * s + (k * wk) @ v
@ -1358,7 +1228,7 @@ class RWKV(MyModule):
out = out.transpose(0, 1).contiguous().reshape(T, H * S) out = out.transpose(0, 1).contiguous().reshape(T, H * S)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype) * g out = out.to(dtype=x.dtype) * g
out = gemm(out, ow) out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], s return x + out, xx[-1, :], s
@ -1397,6 +1267,10 @@ class RWKV(MyModule):
rrx, rrx,
rmy, rmy,
rry, rry,
gmx,
grx,
gmy,
gry,
omx, omx,
orx, orx,
omy, omy,
@ -1413,29 +1287,37 @@ class RWKV(MyModule):
S = x.shape[-1] // H S = x.shape[-1] // H
T = x.shape[0] T = x.shape[0]
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) r = (
matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
k = ( k = (
gemm(kx, kw, output_dtype=torch.float32) matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
.view(T, H, S) .view(T, H, S)
.transpose(0, 1) .transpose(0, 1)
.transpose(-2, -1) .transpose(-2, -1)
) )
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) v = (
g = F.silu(gemm(gx, gw)) matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
out = torch.empty((T, H, S), dtype=r.dtype, device=r.device) out = torch.empty((T, H, S), dtype=r.dtype, device=r.device)
for t in range(T): for t in range(T):
rt = r[:, t : t + 1, :] rt = r[:, t : t + 1, :]
kt = k[:, :, t : t + 1] kt = k[:, :, t : t + 1]
vt = v[:, t : t + 1, :] vt = v[:, t : t + 1, :]
at = gemm(kt, vt) at = matmul(kt, vt)
out[t] = (rt @ (t_first * at + s)).squeeze(1) out[t] = (rt @ (t_first * at + s)).squeeze(1)
s = at + t_decay * s s = at + t_decay * s
out = out.reshape(T, H * S) out = out.reshape(T, H * S)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype) * g out = out.to(dtype=x.dtype) * g
out = gemm(out, ow) out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], s return x + out, xx[-1, :], s
@ -1486,63 +1368,12 @@ class RWKV(MyModule):
vx = xx * v_mix + sx * (1 - v_mix) vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix) rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw)) r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
k = gemm(kx, kw, output_dtype=torch.float32) k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32) v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
y, aa, bb, pp = cuda_wkv(T, aa.shape[0], t_decay, t_first, k, v, aa, bb, pp)
out = gemm(r * y.to(x.dtype), ow)
return x + out, xx[-1, :], aa, bb, pp
@MyFunction
def cuda_att_seq_i8(
self,
x,
sx,
aa,
bb,
pp,
ln_w,
ln_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
T, C = x.shape
xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry)
v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)
y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp)
out = self.mm8_seq(r * y, ow, omx, orx, omy, ory) out = matmul(r * y.to(x.dtype), ow, omx, orx, omy, ory)
return x + out, xx[-1, :], aa, bb, pp return x + out, xx[-1, :], aa, bb, pp
# NOTE: decorate with @MyFunction causes JIT error # NOTE: decorate with @MyFunction causes JIT error
@ -1578,6 +1409,10 @@ class RWKV(MyModule):
rrx, rrx,
rmy, rmy,
rry, rry,
gmx,
grx,
gmy,
gry,
omx, omx,
orx, orx,
omy, omy,
@ -1594,10 +1429,10 @@ class RWKV(MyModule):
N = x.shape[-1] // H N = x.shape[-1] // H
T = x.shape[0] T = x.shape[0]
r = gemm(rx, rw, output_dtype=torch.float32) r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
k = gemm(kx, kw, output_dtype=torch.float32) k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32) v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
g = F.silu(gemm(gx, gw)) g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
out, s = self.RUN_RWKV_5( out, s = self.RUN_RWKV_5(
1, 1,
@ -1616,7 +1451,7 @@ class RWKV(MyModule):
out = out.reshape(T, H * N) out = out.reshape(T, H * N)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype) * g out = out.to(dtype=x.dtype) * g
out = gemm(out, ow) out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], s return x + out, xx[-1, :], s
@ -1703,13 +1538,9 @@ class RWKV(MyModule):
"RWKV_CUDA_ON" "RWKV_CUDA_ON"
] == "1" and "cuda" in str(dev) ] == "1" and "cuda" in str(dev)
if cuda_applicable: if cuda_applicable:
ATT = ( ATT = self.cuda_att_seq
self.cuda_att_seq
if wtype != torch.uint8
else self.cuda_att_seq_i8
)
else: else:
ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8 ATT = self.att_seq
if self.version == 5: if self.version == 5:
ATT = self.att_seq_v5 ATT = self.att_seq_v5
elif self.version == 5.1: elif self.version == 5.1:
@ -1718,16 +1549,16 @@ class RWKV(MyModule):
ATT = self.att_seq_v5_2 ATT = self.att_seq_v5_2
if cuda_applicable: if cuda_applicable:
ATT = self.cuda_att_seq_v5_2 ATT = self.cuda_att_seq_v5_2
FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8 FFN = self.ffn_seq
else: else:
ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8 ATT = self.att_one
if self.version == 5: if self.version == 5:
ATT = self.att_one_v5 ATT = self.att_one_v5
elif self.version == 5.1: elif self.version == 5.1:
ATT = self.att_one_v5_1 ATT = self.att_one_v5_1
elif self.version == 5.2: elif self.version == 5.2:
ATT = self.att_one_v5_1 # same as v5.1 ATT = self.att_one_v5_1 # same as v5.1
FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8 FFN = self.ffn_one
x = x.to(dtype=atype, device=dev) x = x.to(dtype=atype, device=dev)
@ -1872,6 +1703,10 @@ class RWKV(MyModule):
rrx, rrx,
rmy, rmy,
rry, rry,
gmx,
grx,
gmy,
gry,
omx, omx,
orx, orx,
omy, omy,
@ -1944,7 +1779,7 @@ class RWKV(MyModule):
x = x @ w["head.weight"] x = x @ w["head.weight"]
else: else:
if seq_mode and full_output: if seq_mode and full_output:
x = self.mm8_seq( x = mm8_seq(
x, x,
w["head.weight"], w["head.weight"],
w["head.weight_mx"], w["head.weight_mx"],
@ -1953,7 +1788,7 @@ class RWKV(MyModule):
w["head.weight_ry"], w["head.weight_ry"],
) )
else: else:
x = self.mm8_one( x = mm8_one(
x, x,
w["head.weight"], w["head.weight"],
w["head.weight_mx"], w["head.weight_mx"],

Binary file not shown.

View File

@ -81,6 +81,7 @@ class PIPELINE:
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0): def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
probs = F.softmax(logits.float(), dim=-1) probs = F.softmax(logits.float(), dim=-1)
top_k = int(top_k) top_k = int(top_k)
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
if probs.device.type in ["cpu", "privateuseone"]: if probs.device.type in ["cpu", "privateuseone"]:
probs = probs.cpu().numpy() probs = probs.cpu().numpy()
sorted_ids = np.argsort(probs) sorted_ids = np.argsort(probs)

Binary file not shown.