upgrade to rwkv 0.8.20

josc146 2023-11-03 23:27:14 +08:00
parent 35e92d2aef
commit 1f81a1e5a8
6 changed files with 188 additions and 359 deletions


@@ -3,6 +3,8 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#define CUBLAS_CHECK(condition) \
for (cublasStatus_t _cublas_check_status = (condition); \
@@ -18,26 +20,13 @@
"CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
" at " + std::to_string(__LINE__));
cublasHandle_t get_cublas_handle() {
static cublasHandle_t cublas_handle = []() {
cublasHandle_t handle = nullptr;
CUBLAS_CHECK(cublasCreate(&handle));
#if CUDA_VERSION < 11000
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#else
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
return handle;
}();
return cublas_handle;
}
/*
NOTE: blas gemm is column-major by default, but we need row-major output.
The data of row-major, transposed matrix is exactly the same as the
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
*/
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
const auto cuda_data_type = CUDA_R_16F;
const auto cuda_c_data_type =
c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
@@ -55,7 +44,7 @@ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
const int cublas_lda = m;
const int cublas_ldb = k;
const int cublas_ldc = m;
cublasHandle_t cublas_handle = get_cublas_handle();
cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
#if CUDA_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
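
The NOTE above rests on the identity C = A·B ⇔ Cᵀ = Bᵀ·Aᵀ: a row-major matrix handed to a column-major BLAS is read as its transpose, so asking cuBLAS for Bᵀ·Aᵀ produces the row-major C without any extra copies. A minimal torch sketch (illustration only, not part of this diff) that checks the identity:

import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)
c = a @ b                    # the row-major product we actually want
c_t = b.t() @ a.t()          # what a column-major BLAS is asked to compute
assert torch.allclose(c, c_t.t(), atol=1e-6)  # same result, just reinterpreted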


@@ -1,5 +1,6 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;
@@ -9,12 +10,15 @@ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}


@@ -171,10 +171,86 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
else:
os.environ["RWKV_CUDA_ON"] = "0"
if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
@MyStatic
def torch_mm8_seq(x, w, mx, rx, my, ry):
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
@MyStatic
def torch_mm8_one(x, w, mx, rx, my, ry):
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
if os.environ.get("RWKV_CUDA_ON") == "1":
@MyStatic
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
def mm8_seq(x, w, mx, rx, my, ry):
if w.device.type == "cuda" and x.dtype == torch.float16:
B, N, M = x.shape[0], w.shape[0], w.shape[1]
return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
else:
return torch_mm8_seq(x, w, mx, rx, my, ry)
@MyStatic
def mm8_one(x, w, mx, rx, my, ry):
if w.device.type == "cuda":
N, M = w.shape[0], w.shape[1]
return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
else:
return torch_mm8_one(x, w, mx, rx, my, ry)
else:
@MyStatic
def mm8_seq(x, w, mx, rx, my, ry):
return torch_mm8_seq(x, w, mx, rx, my, ry)
@MyStatic
def mm8_one(x, w, mx, rx, my, ry):
return torch_mm8_one(x, w, mx, rx, my, ry)
def mm8(
x: torch.Tensor,
w: torch.Tensor,
mx: torch.Tensor,
rx: torch.Tensor,
my: torch.Tensor,
ry: torch.Tensor,
):
if len(x.shape) == 1:
return mm8_one(x, w, mx, rx, my, ry)
return mm8_seq(x, w, mx, rx, my, ry)
def matmul(
a,
b,
mx: Optional[torch.Tensor] = None,
rx: Optional[torch.Tensor] = None,
my: Optional[torch.Tensor] = None,
ry: Optional[torch.Tensor] = None,
output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if output_dtype is None:
output_dtype = a.dtype
if b.dtype in [torch.float16, torch.bfloat16, torch.float32]:
assert a.dtype == b.dtype
return matmul_float(a, b, output_dtype=output_dtype)
elif b.dtype == torch.uint8:
assert mx is not None
assert rx is not None
assert my is not None
assert ry is not None
return mm8(a, b, mx, rx, my, ry).to(output_dtype)
else:
raise ValueError("Unsupported dtype")
if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None):
if output_dtype is None:
output_dtype = a.dtype
if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda":
@@ -203,9 +279,7 @@ if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
else:
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
if output_dtype is None:
output_dtype = a.dtype
def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None):
return (a @ b).to(output_dtype)
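
The hunks above collapse the old gemm/mm8_* helpers into a single matmul that dispatches on the weight dtype: float16/bfloat16/float32 weights go through matmul_float (the cuBLAS fp16 path on CUDA, plain a @ b otherwise), while uint8 weights take the mm8 path, which dequantizes with the (w + 0.5) * ry * rx + my + mx formula from torch_mm8_seq/torch_mm8_one. Below is a minimal CPU sketch of that uint8 reference path; the shapes and broadcasting layout are assumptions for illustration, not the library's exact quantization layout:

import torch

def mm8_reference(x, w, mx, rx, my, ry):
    # dequantize the uint8 weight, then do an ordinary matmul
    return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)

N, M = 8, 16
w = torch.randint(0, 256, (N, M), dtype=torch.uint8)
ry, my = torch.rand(N, 1), torch.rand(N, 1)   # assumed per-row scale / offset
rx, mx = torch.rand(1, M), torch.rand(1, M)   # assumed per-column scale / offset
x = torch.rand(4, N)                          # a small batch of inputs
print(mm8_reference(x, w, mx, rx, my, ry).shape)  # torch.Size([4, 16])
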
@@ -644,42 +718,6 @@ class RWKV(MyModule):
def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u):
return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u)
@MyFunction
def torch_mm8_seq(self, x, w, mx, rx, my, ry):
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
@MyFunction
def torch_mm8_one(self, x, w, mx, rx, my, ry):
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
if os.environ.get("RWKV_CUDA_ON") == "1":
@MyFunction
def mm8_seq(self, x, w, mx, rx, my, ry):
if w.device.type == "cuda" and x.dtype == torch.float16:
B, N, M = x.shape[0], w.shape[0], w.shape[1]
return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
else:
return self.torch_mm8_seq(x, w, mx, rx, my, ry)
@MyFunction
def mm8_one(self, x, w, mx, rx, my, ry):
if w.device.type == "cuda":
N, M = w.shape[0], w.shape[1]
return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
else:
return self.torch_mm8_one(x, w, mx, rx, my, ry)
else:
@MyFunction
def mm8_seq(self, x, w, mx, rx, my, ry):
return self.torch_mm8_seq(x, w, mx, rx, my, ry)
@MyFunction
def mm8_one(self, x, w, mx, rx, my, ry):
return self.torch_mm8_one(x, w, mx, rx, my, ry)
########################################################################################################
@MyFunction
@@ -711,43 +749,9 @@ class RWKV(MyModule):
kx = xx * k_mix + sx * (1 - k_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw))
vx = torch.square(torch.relu(gemm(kx, kw)))
out = r * gemm(vx, vw)
return x + out, xx
@MyFunction
def ffn_one_i8(
self,
x,
sx,
ln_w,
ln_b,
k_mix,
r_mix,
kw,
vw,
rw,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
kx = xx * k_mix + sx * (1 - k_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry))
vx = torch.square(torch.relu(self.mm8_one(kx, kw, kmx, krx, kmy, kry)))
out = r * (self.mm8_one(vx, vw, vmx, vrx, vmy, vry))
r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)))
out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
return x + out, xx
########################################################################################################
@@ -782,44 +786,9 @@ class RWKV(MyModule):
kx = xx * k_mix + sx * (1 - k_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw))
vx = torch.square(torch.relu(gemm(kx, kw)))
out = r * gemm(vx, vw)
return x + out, xx[-1, :]
@MyFunction
def ffn_seq_i8(
self,
x,
sx,
ln_w,
ln_b,
k_mix,
r_mix,
kw,
vw,
rw,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
kx = xx * k_mix + sx * (1 - k_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
vx = torch.square(torch.relu(self.mm8_seq(kx, kw, kmx, krx, kmy, kry)))
out = r * (self.mm8_seq(vx, vw, vmx, vrx, vmy, vry))
r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)))
out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
return x + out, xx[-1, :]
########################################################################################################
@@ -865,9 +834,9 @@ class RWKV(MyModule):
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw))
k = gemm(kx, kw, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32)
r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
ww = t_first + k
p = torch.maximum(pp, ww)
@@ -879,65 +848,7 @@ class RWKV(MyModule):
e1 = torch.exp(ww - p)
e2 = torch.exp(k - p)
out = gemm(r * wkv, ow)
return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p
@MyFunction
def att_one_i8(
self,
x,
sx,
aa,
bb,
pp,
ln_w,
ln_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry))
k = (self.mm8_one(kx, kw, kmx, krx, kmy, kry)).float()
v = (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)).float()
ww = t_first + k
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
ww = t_decay + pp
p = torch.maximum(ww, k)
e1 = torch.exp(ww - p)
e2 = torch.exp(k - p)
out = self.mm8_one(r * wkv, ow, omx, orx, omy, ory)
out = matmul(r * wkv, ow, omx, orx, omy, ory)
return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p
########################################################################################################
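
In att_one (and the removed att_one_i8) the WKV state is carried as a numerator aa, a denominator bb and a shared exponent pp, so every exponential is evaluated relative to the running maximum, a log-sum-exp style trick that keeps fp16/fp32 from overflowing. A standalone sketch of that single-token update with plain floats, mirroring the tensor code above (illustration only):

import math

def wkv_one_step(k, v, aa, bb, pp, t_first, t_decay):
    # output for the current token, with everything shifted by p = max(...)
    ww = t_first + k
    p = max(pp, ww)
    e1, e2 = math.exp(pp - p), math.exp(ww - p)
    wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)
    # state update for the next token
    ww = t_decay + pp
    p = max(ww, k)
    e1, e2 = math.exp(ww - p), math.exp(k - p)
    return wkv, e1 * aa + e2 * v, e1 * bb + e2, p

wkv, aa, bb, pp = wkv_one_step(k=0.3, v=1.0, aa=0.0, bb=0.0, pp=-1e30,
                               t_first=0.1, t_decay=-0.2)
print(wkv, aa, bb, pp)
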
@@ -984,9 +895,9 @@ class RWKV(MyModule):
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw))
k = gemm(kx, kw, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32)
r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
T = x.shape[0]
for t in range(T):
@@ -1004,72 +915,7 @@ class RWKV(MyModule):
aa = e1 * aa + e2 * vv
bb = e1 * bb + e2
pp = p
out = gemm(r * sx, ow)
return x + out, xx[-1, :], aa, bb, pp
@MyFunction
def att_seq_i8(
self,
x,
sx,
aa,
bb,
pp,
ln_w,
ln_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry).float()
v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry).float()
T = x.shape[0]
for t in range(T):
kk = k[t]
vv = v[t]
ww = t_first + kk
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
ww = t_decay + pp
p = torch.maximum(ww, kk)
e1 = torch.exp(ww - p)
e2 = torch.exp(kk - p)
aa = e1 * aa + e2 * vv
bb = e1 * bb + e2
pp = p
out = self.mm8_seq(r * sx, ow, omx, orx, omy, ory)
out = matmul(r * sx, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], aa, bb, pp
########################################################################################################
@@ -1118,11 +964,11 @@ class RWKV(MyModule):
H = t_decay.shape[0]
S = x.shape[-1] // H
r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S)
k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1)
v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S)
r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S)
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1)
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S)
a = gemm(k, v)
a = matmul(k, v)
out = r @ (t_first * a + s)
s = a + t_decay * s
@@ -1131,7 +977,7 @@ class RWKV(MyModule):
out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
).squeeze(0)
out = out.to(dtype=x.dtype)
out = gemm(out, ow)
out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx, s
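
In the v5 single-token path each head carries an S×S state s: k (S×1) and v (1×S) form an outer product a = k @ v, the head output is r @ (t_first * a + s), and the state decays as s = a + t_decay * s, which is exactly the sequence of matmul calls above. A tiny per-head sketch with random tensors; the shapes here are illustrative assumptions:

import torch

H, S = 4, 8                      # heads, head size (illustrative)
r = torch.randn(H, 1, S)
k = torch.randn(H, S, 1)
v = torch.randn(H, 1, S)
t_first = torch.randn(H, S, 1)   # assumed per-head, per-channel bonus
t_decay = torch.rand(H, S, 1)    # assumed per-head, per-channel decay
s = torch.zeros(H, S, S)         # state carried between tokens

a = k @ v                        # outer product, (H, S, S)
out = r @ (t_first * a + s)      # (H, 1, S) per-head output
s = a + t_decay * s              # decayed state for the next token
print(out.shape, s.shape)
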
@@ -1194,14 +1040,22 @@ class RWKV(MyModule):
w = w[:, :-T].reshape(-1, T, 2 * T - 1)
w = w[:, :, T - 1 :].reshape(H, T, T)
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
r = (
matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
k = (
gemm(kx, kw, output_dtype=torch.float32)
matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
.transpose(-2, -1)
)
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
v = (
matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
out = ((r @ k) * w) @ v + (r @ s) * wb
s = ws * s + (k * wk) @ v
@@ -1209,7 +1063,7 @@ class RWKV(MyModule):
out = out.transpose(0, 1).contiguous().reshape(T, H * S)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype)
out = gemm(out, ow)
out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], s
@@ -1248,6 +1102,10 @@ class RWKV(MyModule):
rrx,
rmy,
rry,
gmx,
grx,
gmy,
gry,
omx,
orx,
omy,
@@ -1262,12 +1120,12 @@ class RWKV(MyModule):
H = t_decay.shape[0]
S = x.shape[-1] // H
r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S)
k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1)
v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S)
g = F.silu(gemm(gx, gw))
r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S)
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1)
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S)
g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
a = gemm(k, v)
a = matmul(k, v)
out = r @ (t_first * a + s)
s = a + t_decay * s
@@ -1276,7 +1134,7 @@ class RWKV(MyModule):
out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
).squeeze(0)
out = out.to(dtype=x.dtype) * g
out = gemm(out, ow)
out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx, s
@@ -1313,6 +1171,10 @@ class RWKV(MyModule):
rrx,
rmy,
rry,
gmx,
grx,
gmy,
gry,
omx,
orx,
omy,
@@ -1342,15 +1204,23 @@ class RWKV(MyModule):
w = w[:, :-T].reshape(-1, T, 2 * T - 1)
w = w[:, :, T - 1 :].reshape(H, T, T)
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
r = (
matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
k = (
gemm(kx, kw, output_dtype=torch.float32)
matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
.transpose(-2, -1)
)
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
g = F.silu(gemm(gx, gw))
v = (
matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
out = ((r @ k) * w) @ v + (r @ s) * wb
s = ws * s + (k * wk) @ v
@@ -1358,7 +1228,7 @@ class RWKV(MyModule):
out = out.transpose(0, 1).contiguous().reshape(T, H * S)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype) * g
out = gemm(out, ow)
out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], s
@@ -1397,6 +1267,10 @@ class RWKV(MyModule):
rrx,
rmy,
rry,
gmx,
grx,
gmy,
gry,
omx,
orx,
omy,
@@ -1413,29 +1287,37 @@ class RWKV(MyModule):
S = x.shape[-1] // H
T = x.shape[0]
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
r = (
matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
k = (
gemm(kx, kw, output_dtype=torch.float32)
matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
.transpose(-2, -1)
)
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
g = F.silu(gemm(gx, gw))
v = (
matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
)
g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
out = torch.empty((T, H, S), dtype=r.dtype, device=r.device)
for t in range(T):
rt = r[:, t : t + 1, :]
kt = k[:, :, t : t + 1]
vt = v[:, t : t + 1, :]
at = gemm(kt, vt)
at = matmul(kt, vt)
out[t] = (rt @ (t_first * at + s)).squeeze(1)
s = at + t_decay * s
out = out.reshape(T, H * S)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype) * g
out = gemm(out, ow)
out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], s
@@ -1486,63 +1368,12 @@ class RWKV(MyModule):
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw))
k = gemm(kx, kw, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32)
y, aa, bb, pp = cuda_wkv(T, aa.shape[0], t_decay, t_first, k, v, aa, bb, pp)
out = gemm(r * y.to(x.dtype), ow)
return x + out, xx[-1, :], aa, bb, pp
@MyFunction
def cuda_att_seq_i8(
self,
x,
sx,
aa,
bb,
pp,
ln_w,
ln_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
T, C = x.shape
xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry)
v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)
r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp)
out = self.mm8_seq(r * y, ow, omx, orx, omy, ory)
out = matmul(r * y.to(x.dtype), ow, omx, orx, omy, ory)
return x + out, xx[-1, :], aa, bb, pp
# NOTE: decorate with @MyFunction causes JIT error
@@ -1578,6 +1409,10 @@ class RWKV(MyModule):
rrx,
rmy,
rry,
gmx,
grx,
gmy,
gry,
omx,
orx,
omy,
@@ -1594,10 +1429,10 @@ class RWKV(MyModule):
N = x.shape[-1] // H
T = x.shape[0]
r = gemm(rx, rw, output_dtype=torch.float32)
k = gemm(kx, kw, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32)
g = F.silu(gemm(gx, gw))
r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
out, s = self.RUN_RWKV_5(
1,
@@ -1616,7 +1451,7 @@ class RWKV(MyModule):
out = out.reshape(T, H * N)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype) * g
out = gemm(out, ow)
out = matmul(out, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], s
@@ -1703,13 +1538,9 @@ class RWKV(MyModule):
"RWKV_CUDA_ON"
] == "1" and "cuda" in str(dev)
if cuda_applicable:
ATT = (
self.cuda_att_seq
if wtype != torch.uint8
else self.cuda_att_seq_i8
)
ATT = self.cuda_att_seq
else:
ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8
ATT = self.att_seq
if self.version == 5:
ATT = self.att_seq_v5
elif self.version == 5.1:
@@ -1718,16 +1549,16 @@ class RWKV(MyModule):
ATT = self.att_seq_v5_2
if cuda_applicable:
ATT = self.cuda_att_seq_v5_2
FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8
FFN = self.ffn_seq
else:
ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
ATT = self.att_one
if self.version == 5:
ATT = self.att_one_v5
elif self.version == 5.1:
ATT = self.att_one_v5_1
elif self.version == 5.2:
ATT = self.att_one_v5_1 # same as v5.1
FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8
FFN = self.ffn_one
x = x.to(dtype=atype, device=dev)
@@ -1872,6 +1703,10 @@ class RWKV(MyModule):
rrx,
rmy,
rry,
gmx,
grx,
gmy,
gry,
omx,
orx,
omy,
@@ -1944,7 +1779,7 @@ class RWKV(MyModule):
x = x @ w["head.weight"]
else:
if seq_mode and full_output:
x = self.mm8_seq(
x = mm8_seq(
x,
w["head.weight"],
w["head.weight_mx"],
@@ -1953,7 +1788,7 @@ class RWKV(MyModule):
w["head.weight_ry"],
)
else:
x = self.mm8_one(
x = mm8_one(
x,
w["head.weight"],
w["head.weight_mx"],

Binary file not shown.


@@ -81,6 +81,7 @@ class PIPELINE:
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
probs = F.softmax(logits.float(), dim=-1)
top_k = int(top_k)
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
if probs.device.type in ["cpu", "privateuseone"]:
probs = probs.cpu().numpy()
sorted_ids = np.argsort(probs)
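
After the softmax, the CPU/numpy branch applies top-p (nucleus) filtering, an optional top-k cut, and a temperature power before drawing a token. A minimal numpy sketch of that strategy (an approximation for illustration, not the pipeline's exact code):

import numpy as np

def sample_cpu(probs, temperature=1.0, top_p=0.85, top_k=0):
    probs = probs.astype(np.float64).copy()
    sorted_ids = np.argsort(probs)                 # ascending, as above
    sorted_probs = probs[sorted_ids][::-1]         # descending
    cutoff = sorted_probs[np.argmax(np.cumsum(sorted_probs) >= top_p)]
    probs[probs < cutoff] = 0.0                    # nucleus (top-p) filter
    if 0 < top_k < len(probs):
        probs[sorted_ids[:-top_k]] = 0.0           # keep only the k largest
    if temperature != 1.0:
        probs = probs ** (1.0 / temperature)
    probs /= probs.sum()
    return int(np.random.choice(len(probs), p=probs))

print(sample_cpu(np.array([0.1, 0.2, 0.3, 0.4]), top_p=0.6))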

Binary file not shown.