upgrade to rwkv 0.8.20
This commit is contained in:
parent
35e92d2aef
commit
1f81a1e5a8
@ -3,6 +3,8 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#define CUBLAS_CHECK(condition) \
|
||||
for (cublasStatus_t _cublas_check_status = (condition); \
|
||||
@ -18,26 +20,13 @@
|
||||
"CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
|
||||
" at " + std::to_string(__LINE__));
|
||||
|
||||
cublasHandle_t get_cublas_handle() {
|
||||
static cublasHandle_t cublas_handle = []() {
|
||||
cublasHandle_t handle = nullptr;
|
||||
CUBLAS_CHECK(cublasCreate(&handle));
|
||||
#if CUDA_VERSION < 11000
|
||||
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
|
||||
#else
|
||||
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
|
||||
#endif // CUDA_VERSION < 11000
|
||||
return handle;
|
||||
}();
|
||||
return cublas_handle;
|
||||
}
|
||||
|
||||
/*
|
||||
NOTE: blas gemm is column-major by default, but we need row-major output.
|
||||
The data of row-major, transposed matrix is exactly the same as the
|
||||
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
|
||||
*/
|
||||
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
const auto cuda_data_type = CUDA_R_16F;
|
||||
const auto cuda_c_data_type =
|
||||
c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
|
||||
@ -55,7 +44,7 @@ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
|
||||
const int cublas_lda = m;
|
||||
const int cublas_ldb = k;
|
||||
const int cublas_ldc = m;
|
||||
cublasHandle_t cublas_handle = get_cublas_handle();
|
||||
cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
|
4
backend-python/rwkv_pip/cuda/rwkv5_op.cpp
vendored
4
backend-python/rwkv_pip/cuda/rwkv5_op.cpp
vendored
@ -1,5 +1,6 @@
|
||||
#include <torch/extension.h>
|
||||
#include "ATen/ATen.h"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
typedef at::BFloat16 bf16;
|
||||
typedef at::Half fp16;
|
||||
typedef float fp32;
|
||||
@ -9,12 +10,15 @@ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *
|
||||
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
|
||||
|
||||
void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
||||
cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
|
||||
}
|
||||
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
||||
cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
|
||||
}
|
||||
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
|
||||
cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
|
||||
}
|
||||
|
||||
|
523
backend-python/rwkv_pip/model.py
vendored
523
backend-python/rwkv_pip/model.py
vendored
@ -171,10 +171,86 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
|
||||
else:
|
||||
os.environ["RWKV_CUDA_ON"] = "0"
|
||||
|
||||
if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
|
||||
|
||||
@MyStatic
|
||||
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
|
||||
def torch_mm8_seq(x, w, mx, rx, my, ry):
|
||||
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
|
||||
|
||||
|
||||
@MyStatic
|
||||
def torch_mm8_one(x, w, mx, rx, my, ry):
|
||||
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
|
||||
|
||||
|
||||
if os.environ.get("RWKV_CUDA_ON") == "1":
|
||||
|
||||
@MyStatic
|
||||
def mm8_seq(x, w, mx, rx, my, ry):
|
||||
if w.device.type == "cuda" and x.dtype == torch.float16:
|
||||
B, N, M = x.shape[0], w.shape[0], w.shape[1]
|
||||
return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
|
||||
else:
|
||||
return torch_mm8_seq(x, w, mx, rx, my, ry)
|
||||
|
||||
@MyStatic
|
||||
def mm8_one(x, w, mx, rx, my, ry):
|
||||
if w.device.type == "cuda":
|
||||
N, M = w.shape[0], w.shape[1]
|
||||
return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
|
||||
else:
|
||||
return torch_mm8_one(x, w, mx, rx, my, ry)
|
||||
|
||||
else:
|
||||
|
||||
@MyStatic
|
||||
def mm8_seq(x, w, mx, rx, my, ry):
|
||||
return torch_mm8_seq(x, w, mx, rx, my, ry)
|
||||
|
||||
@MyStatic
|
||||
def mm8_one(x, w, mx, rx, my, ry):
|
||||
return torch_mm8_one(x, w, mx, rx, my, ry)
|
||||
|
||||
|
||||
def mm8(
|
||||
x: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
mx: torch.Tensor,
|
||||
rx: torch.Tensor,
|
||||
my: torch.Tensor,
|
||||
ry: torch.Tensor,
|
||||
):
|
||||
if len(x.shape) == 1:
|
||||
return mm8_one(x, w, mx, rx, my, ry)
|
||||
return mm8_seq(x, w, mx, rx, my, ry)
|
||||
|
||||
|
||||
def matmul(
|
||||
a,
|
||||
b,
|
||||
mx: Optional[torch.Tensor] = None,
|
||||
rx: Optional[torch.Tensor] = None,
|
||||
my: Optional[torch.Tensor] = None,
|
||||
ry: Optional[torch.Tensor] = None,
|
||||
output_dtype: Optional[torch.dtype] = None,
|
||||
) -> torch.Tensor:
|
||||
if output_dtype is None:
|
||||
output_dtype = a.dtype
|
||||
if b.dtype in [torch.float16, torch.bfloat16, torch.float32]:
|
||||
assert a.dtype == b.dtype
|
||||
return matmul_float(a, b, output_dtype=output_dtype)
|
||||
elif b.dtype == torch.uint8:
|
||||
assert mx is not None
|
||||
assert rx is not None
|
||||
assert my is not None
|
||||
assert ry is not None
|
||||
return mm8(a, b, mx, rx, my, ry).to(output_dtype)
|
||||
else:
|
||||
raise ValueError("Unsupported dtype")
|
||||
|
||||
|
||||
if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
|
||||
|
||||
def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None):
|
||||
if output_dtype is None:
|
||||
output_dtype = a.dtype
|
||||
if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda":
|
||||
@ -203,9 +279,7 @@ if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
|
||||
|
||||
else:
|
||||
|
||||
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
|
||||
if output_dtype is None:
|
||||
output_dtype = a.dtype
|
||||
def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None):
|
||||
return (a @ b).to(output_dtype)
|
||||
|
||||
|
||||
@ -644,42 +718,6 @@ class RWKV(MyModule):
|
||||
def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u):
|
||||
return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u)
|
||||
|
||||
@MyFunction
|
||||
def torch_mm8_seq(self, x, w, mx, rx, my, ry):
|
||||
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
|
||||
|
||||
@MyFunction
|
||||
def torch_mm8_one(self, x, w, mx, rx, my, ry):
|
||||
return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
|
||||
|
||||
if os.environ.get("RWKV_CUDA_ON") == "1":
|
||||
|
||||
@MyFunction
|
||||
def mm8_seq(self, x, w, mx, rx, my, ry):
|
||||
if w.device.type == "cuda" and x.dtype == torch.float16:
|
||||
B, N, M = x.shape[0], w.shape[0], w.shape[1]
|
||||
return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
|
||||
else:
|
||||
return self.torch_mm8_seq(x, w, mx, rx, my, ry)
|
||||
|
||||
@MyFunction
|
||||
def mm8_one(self, x, w, mx, rx, my, ry):
|
||||
if w.device.type == "cuda":
|
||||
N, M = w.shape[0], w.shape[1]
|
||||
return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
|
||||
else:
|
||||
return self.torch_mm8_one(x, w, mx, rx, my, ry)
|
||||
|
||||
else:
|
||||
|
||||
@MyFunction
|
||||
def mm8_seq(self, x, w, mx, rx, my, ry):
|
||||
return self.torch_mm8_seq(x, w, mx, rx, my, ry)
|
||||
|
||||
@MyFunction
|
||||
def mm8_one(self, x, w, mx, rx, my, ry):
|
||||
return self.torch_mm8_one(x, w, mx, rx, my, ry)
|
||||
|
||||
########################################################################################################
|
||||
|
||||
@MyFunction
|
||||
@ -711,43 +749,9 @@ class RWKV(MyModule):
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(gemm(rx, rw))
|
||||
vx = torch.square(torch.relu(gemm(kx, kw)))
|
||||
out = r * gemm(vx, vw)
|
||||
return x + out, xx
|
||||
|
||||
@MyFunction
|
||||
def ffn_one_i8(
|
||||
self,
|
||||
x,
|
||||
sx,
|
||||
ln_w,
|
||||
ln_b,
|
||||
k_mix,
|
||||
r_mix,
|
||||
kw,
|
||||
vw,
|
||||
rw,
|
||||
kmx,
|
||||
krx,
|
||||
kmy,
|
||||
kry,
|
||||
vmx,
|
||||
vrx,
|
||||
vmy,
|
||||
vry,
|
||||
rmx,
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
):
|
||||
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry))
|
||||
vx = torch.square(torch.relu(self.mm8_one(kx, kw, kmx, krx, kmy, kry)))
|
||||
out = r * (self.mm8_one(vx, vw, vmx, vrx, vmy, vry))
|
||||
r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
|
||||
vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)))
|
||||
out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
|
||||
return x + out, xx
|
||||
|
||||
########################################################################################################
|
||||
@ -782,44 +786,9 @@ class RWKV(MyModule):
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(gemm(rx, rw))
|
||||
vx = torch.square(torch.relu(gemm(kx, kw)))
|
||||
out = r * gemm(vx, vw)
|
||||
return x + out, xx[-1, :]
|
||||
|
||||
@MyFunction
|
||||
def ffn_seq_i8(
|
||||
self,
|
||||
x,
|
||||
sx,
|
||||
ln_w,
|
||||
ln_b,
|
||||
k_mix,
|
||||
r_mix,
|
||||
kw,
|
||||
vw,
|
||||
rw,
|
||||
kmx,
|
||||
krx,
|
||||
kmy,
|
||||
kry,
|
||||
vmx,
|
||||
vrx,
|
||||
vmy,
|
||||
vry,
|
||||
rmx,
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
):
|
||||
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
|
||||
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
|
||||
vx = torch.square(torch.relu(self.mm8_seq(kx, kw, kmx, krx, kmy, kry)))
|
||||
out = r * (self.mm8_seq(vx, vw, vmx, vrx, vmy, vry))
|
||||
r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
|
||||
vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)))
|
||||
out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
|
||||
return x + out, xx[-1, :]
|
||||
|
||||
########################################################################################################
|
||||
@ -865,9 +834,9 @@ class RWKV(MyModule):
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(gemm(rx, rw))
|
||||
k = gemm(kx, kw, output_dtype=torch.float32)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32)
|
||||
r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
|
||||
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
|
||||
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
|
||||
|
||||
ww = t_first + k
|
||||
p = torch.maximum(pp, ww)
|
||||
@ -879,65 +848,7 @@ class RWKV(MyModule):
|
||||
e1 = torch.exp(ww - p)
|
||||
e2 = torch.exp(k - p)
|
||||
|
||||
out = gemm(r * wkv, ow)
|
||||
return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p
|
||||
|
||||
@MyFunction
|
||||
def att_one_i8(
|
||||
self,
|
||||
x,
|
||||
sx,
|
||||
aa,
|
||||
bb,
|
||||
pp,
|
||||
ln_w,
|
||||
ln_b,
|
||||
k_mix,
|
||||
v_mix,
|
||||
r_mix,
|
||||
t_decay,
|
||||
t_first,
|
||||
kw,
|
||||
vw,
|
||||
rw,
|
||||
ow,
|
||||
kmx,
|
||||
krx,
|
||||
kmy,
|
||||
kry,
|
||||
vmx,
|
||||
vrx,
|
||||
vmy,
|
||||
vry,
|
||||
rmx,
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
ory,
|
||||
):
|
||||
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry))
|
||||
k = (self.mm8_one(kx, kw, kmx, krx, kmy, kry)).float()
|
||||
v = (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)).float()
|
||||
|
||||
ww = t_first + k
|
||||
p = torch.maximum(pp, ww)
|
||||
e1 = torch.exp(pp - p)
|
||||
e2 = torch.exp(ww - p)
|
||||
wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
|
||||
ww = t_decay + pp
|
||||
p = torch.maximum(ww, k)
|
||||
e1 = torch.exp(ww - p)
|
||||
e2 = torch.exp(k - p)
|
||||
|
||||
out = self.mm8_one(r * wkv, ow, omx, orx, omy, ory)
|
||||
out = matmul(r * wkv, ow, omx, orx, omy, ory)
|
||||
return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p
|
||||
|
||||
########################################################################################################
|
||||
@ -984,9 +895,9 @@ class RWKV(MyModule):
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(gemm(rx, rw))
|
||||
k = gemm(kx, kw, output_dtype=torch.float32)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32)
|
||||
r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
|
||||
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
|
||||
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
|
||||
|
||||
T = x.shape[0]
|
||||
for t in range(T):
|
||||
@ -1004,72 +915,7 @@ class RWKV(MyModule):
|
||||
aa = e1 * aa + e2 * vv
|
||||
bb = e1 * bb + e2
|
||||
pp = p
|
||||
out = gemm(r * sx, ow)
|
||||
return x + out, xx[-1, :], aa, bb, pp
|
||||
|
||||
@MyFunction
|
||||
def att_seq_i8(
|
||||
self,
|
||||
x,
|
||||
sx,
|
||||
aa,
|
||||
bb,
|
||||
pp,
|
||||
ln_w,
|
||||
ln_b,
|
||||
k_mix,
|
||||
v_mix,
|
||||
r_mix,
|
||||
t_decay,
|
||||
t_first,
|
||||
kw,
|
||||
vw,
|
||||
rw,
|
||||
ow,
|
||||
kmx,
|
||||
krx,
|
||||
kmy,
|
||||
kry,
|
||||
vmx,
|
||||
vrx,
|
||||
vmy,
|
||||
vry,
|
||||
rmx,
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
ory,
|
||||
):
|
||||
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
|
||||
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
|
||||
k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry).float()
|
||||
v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry).float()
|
||||
|
||||
T = x.shape[0]
|
||||
for t in range(T):
|
||||
kk = k[t]
|
||||
vv = v[t]
|
||||
ww = t_first + kk
|
||||
p = torch.maximum(pp, ww)
|
||||
e1 = torch.exp(pp - p)
|
||||
e2 = torch.exp(ww - p)
|
||||
sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
|
||||
ww = t_decay + pp
|
||||
p = torch.maximum(ww, kk)
|
||||
e1 = torch.exp(ww - p)
|
||||
e2 = torch.exp(kk - p)
|
||||
aa = e1 * aa + e2 * vv
|
||||
bb = e1 * bb + e2
|
||||
pp = p
|
||||
out = self.mm8_seq(r * sx, ow, omx, orx, omy, ory)
|
||||
out = matmul(r * sx, ow, omx, orx, omy, ory)
|
||||
return x + out, xx[-1, :], aa, bb, pp
|
||||
|
||||
########################################################################################################
|
||||
@ -1118,11 +964,11 @@ class RWKV(MyModule):
|
||||
H = t_decay.shape[0]
|
||||
S = x.shape[-1] // H
|
||||
|
||||
r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S)
|
||||
k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S)
|
||||
r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S)
|
||||
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1)
|
||||
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S)
|
||||
|
||||
a = gemm(k, v)
|
||||
a = matmul(k, v)
|
||||
out = r @ (t_first * a + s)
|
||||
s = a + t_decay * s
|
||||
|
||||
@ -1131,7 +977,7 @@ class RWKV(MyModule):
|
||||
out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
|
||||
).squeeze(0)
|
||||
out = out.to(dtype=x.dtype)
|
||||
out = gemm(out, ow)
|
||||
out = matmul(out, ow, omx, orx, omy, ory)
|
||||
|
||||
return x + out, xx, s
|
||||
|
||||
@ -1194,14 +1040,22 @@ class RWKV(MyModule):
|
||||
w = w[:, :-T].reshape(-1, T, 2 * T - 1)
|
||||
w = w[:, :, T - 1 :].reshape(H, T, T)
|
||||
|
||||
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
|
||||
r = (
|
||||
matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
|
||||
.view(T, H, S)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
k = (
|
||||
gemm(kx, kw, output_dtype=torch.float32)
|
||||
matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
|
||||
.view(T, H, S)
|
||||
.transpose(0, 1)
|
||||
.transpose(-2, -1)
|
||||
)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
|
||||
v = (
|
||||
matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
|
||||
.view(T, H, S)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
out = ((r @ k) * w) @ v + (r @ s) * wb
|
||||
s = ws * s + (k * wk) @ v
|
||||
@ -1209,7 +1063,7 @@ class RWKV(MyModule):
|
||||
out = out.transpose(0, 1).contiguous().reshape(T, H * S)
|
||||
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
|
||||
out = out.to(dtype=x.dtype)
|
||||
out = gemm(out, ow)
|
||||
out = matmul(out, ow, omx, orx, omy, ory)
|
||||
|
||||
return x + out, xx[-1, :], s
|
||||
|
||||
@ -1248,6 +1102,10 @@ class RWKV(MyModule):
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
gmx,
|
||||
grx,
|
||||
gmy,
|
||||
gry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
@ -1262,12 +1120,12 @@ class RWKV(MyModule):
|
||||
H = t_decay.shape[0]
|
||||
S = x.shape[-1] // H
|
||||
|
||||
r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S)
|
||||
k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S)
|
||||
g = F.silu(gemm(gx, gw))
|
||||
r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S)
|
||||
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1)
|
||||
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S)
|
||||
g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
|
||||
|
||||
a = gemm(k, v)
|
||||
a = matmul(k, v)
|
||||
out = r @ (t_first * a + s)
|
||||
s = a + t_decay * s
|
||||
|
||||
@ -1276,7 +1134,7 @@ class RWKV(MyModule):
|
||||
out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
|
||||
).squeeze(0)
|
||||
out = out.to(dtype=x.dtype) * g
|
||||
out = gemm(out, ow)
|
||||
out = matmul(out, ow, omx, orx, omy, ory)
|
||||
|
||||
return x + out, xx, s
|
||||
|
||||
@ -1313,6 +1171,10 @@ class RWKV(MyModule):
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
gmx,
|
||||
grx,
|
||||
gmy,
|
||||
gry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
@ -1342,15 +1204,23 @@ class RWKV(MyModule):
|
||||
w = w[:, :-T].reshape(-1, T, 2 * T - 1)
|
||||
w = w[:, :, T - 1 :].reshape(H, T, T)
|
||||
|
||||
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
|
||||
r = (
|
||||
matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
|
||||
.view(T, H, S)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
k = (
|
||||
gemm(kx, kw, output_dtype=torch.float32)
|
||||
matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
|
||||
.view(T, H, S)
|
||||
.transpose(0, 1)
|
||||
.transpose(-2, -1)
|
||||
)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
|
||||
g = F.silu(gemm(gx, gw))
|
||||
v = (
|
||||
matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
|
||||
.view(T, H, S)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
|
||||
|
||||
out = ((r @ k) * w) @ v + (r @ s) * wb
|
||||
s = ws * s + (k * wk) @ v
|
||||
@ -1358,7 +1228,7 @@ class RWKV(MyModule):
|
||||
out = out.transpose(0, 1).contiguous().reshape(T, H * S)
|
||||
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
|
||||
out = out.to(dtype=x.dtype) * g
|
||||
out = gemm(out, ow)
|
||||
out = matmul(out, ow, omx, orx, omy, ory)
|
||||
|
||||
return x + out, xx[-1, :], s
|
||||
|
||||
@ -1397,6 +1267,10 @@ class RWKV(MyModule):
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
gmx,
|
||||
grx,
|
||||
gmy,
|
||||
gry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
@ -1413,29 +1287,37 @@ class RWKV(MyModule):
|
||||
S = x.shape[-1] // H
|
||||
T = x.shape[0]
|
||||
|
||||
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
|
||||
r = (
|
||||
matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
|
||||
.view(T, H, S)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
k = (
|
||||
gemm(kx, kw, output_dtype=torch.float32)
|
||||
matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
|
||||
.view(T, H, S)
|
||||
.transpose(0, 1)
|
||||
.transpose(-2, -1)
|
||||
)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
|
||||
g = F.silu(gemm(gx, gw))
|
||||
v = (
|
||||
matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
|
||||
.view(T, H, S)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
|
||||
|
||||
out = torch.empty((T, H, S), dtype=r.dtype, device=r.device)
|
||||
for t in range(T):
|
||||
rt = r[:, t : t + 1, :]
|
||||
kt = k[:, :, t : t + 1]
|
||||
vt = v[:, t : t + 1, :]
|
||||
at = gemm(kt, vt)
|
||||
at = matmul(kt, vt)
|
||||
out[t] = (rt @ (t_first * at + s)).squeeze(1)
|
||||
s = at + t_decay * s
|
||||
|
||||
out = out.reshape(T, H * S)
|
||||
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
|
||||
out = out.to(dtype=x.dtype) * g
|
||||
out = gemm(out, ow)
|
||||
out = matmul(out, ow, omx, orx, omy, ory)
|
||||
|
||||
return x + out, xx[-1, :], s
|
||||
|
||||
@ -1486,63 +1368,12 @@ class RWKV(MyModule):
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(gemm(rx, rw))
|
||||
k = gemm(kx, kw, output_dtype=torch.float32)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32)
|
||||
y, aa, bb, pp = cuda_wkv(T, aa.shape[0], t_decay, t_first, k, v, aa, bb, pp)
|
||||
|
||||
out = gemm(r * y.to(x.dtype), ow)
|
||||
return x + out, xx[-1, :], aa, bb, pp
|
||||
|
||||
@MyFunction
|
||||
def cuda_att_seq_i8(
|
||||
self,
|
||||
x,
|
||||
sx,
|
||||
aa,
|
||||
bb,
|
||||
pp,
|
||||
ln_w,
|
||||
ln_b,
|
||||
k_mix,
|
||||
v_mix,
|
||||
r_mix,
|
||||
t_decay,
|
||||
t_first,
|
||||
kw,
|
||||
vw,
|
||||
rw,
|
||||
ow,
|
||||
kmx,
|
||||
krx,
|
||||
kmy,
|
||||
kry,
|
||||
vmx,
|
||||
vrx,
|
||||
vmy,
|
||||
vry,
|
||||
rmx,
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
ory,
|
||||
):
|
||||
T, C = x.shape
|
||||
xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b)
|
||||
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
|
||||
k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry)
|
||||
v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)
|
||||
r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
|
||||
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
|
||||
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
|
||||
y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp)
|
||||
|
||||
out = self.mm8_seq(r * y, ow, omx, orx, omy, ory)
|
||||
out = matmul(r * y.to(x.dtype), ow, omx, orx, omy, ory)
|
||||
return x + out, xx[-1, :], aa, bb, pp
|
||||
|
||||
# NOTE: decorate with @MyFunction causes JIT error
|
||||
@ -1578,6 +1409,10 @@ class RWKV(MyModule):
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
gmx,
|
||||
grx,
|
||||
gmy,
|
||||
gry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
@ -1594,10 +1429,10 @@ class RWKV(MyModule):
|
||||
N = x.shape[-1] // H
|
||||
T = x.shape[0]
|
||||
|
||||
r = gemm(rx, rw, output_dtype=torch.float32)
|
||||
k = gemm(kx, kw, output_dtype=torch.float32)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32)
|
||||
g = F.silu(gemm(gx, gw))
|
||||
r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
|
||||
k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
|
||||
v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
|
||||
g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
|
||||
|
||||
out, s = self.RUN_RWKV_5(
|
||||
1,
|
||||
@ -1616,7 +1451,7 @@ class RWKV(MyModule):
|
||||
out = out.reshape(T, H * N)
|
||||
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
|
||||
out = out.to(dtype=x.dtype) * g
|
||||
out = gemm(out, ow)
|
||||
out = matmul(out, ow, omx, orx, omy, ory)
|
||||
|
||||
return x + out, xx[-1, :], s
|
||||
|
||||
@ -1703,13 +1538,9 @@ class RWKV(MyModule):
|
||||
"RWKV_CUDA_ON"
|
||||
] == "1" and "cuda" in str(dev)
|
||||
if cuda_applicable:
|
||||
ATT = (
|
||||
self.cuda_att_seq
|
||||
if wtype != torch.uint8
|
||||
else self.cuda_att_seq_i8
|
||||
)
|
||||
ATT = self.cuda_att_seq
|
||||
else:
|
||||
ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8
|
||||
ATT = self.att_seq
|
||||
if self.version == 5:
|
||||
ATT = self.att_seq_v5
|
||||
elif self.version == 5.1:
|
||||
@ -1718,16 +1549,16 @@ class RWKV(MyModule):
|
||||
ATT = self.att_seq_v5_2
|
||||
if cuda_applicable:
|
||||
ATT = self.cuda_att_seq_v5_2
|
||||
FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8
|
||||
FFN = self.ffn_seq
|
||||
else:
|
||||
ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
|
||||
ATT = self.att_one
|
||||
if self.version == 5:
|
||||
ATT = self.att_one_v5
|
||||
elif self.version == 5.1:
|
||||
ATT = self.att_one_v5_1
|
||||
elif self.version == 5.2:
|
||||
ATT = self.att_one_v5_1 # same as v5.1
|
||||
FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8
|
||||
FFN = self.ffn_one
|
||||
|
||||
x = x.to(dtype=atype, device=dev)
|
||||
|
||||
@ -1872,6 +1703,10 @@ class RWKV(MyModule):
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
gmx,
|
||||
grx,
|
||||
gmy,
|
||||
gry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
@ -1944,7 +1779,7 @@ class RWKV(MyModule):
|
||||
x = x @ w["head.weight"]
|
||||
else:
|
||||
if seq_mode and full_output:
|
||||
x = self.mm8_seq(
|
||||
x = mm8_seq(
|
||||
x,
|
||||
w["head.weight"],
|
||||
w["head.weight_mx"],
|
||||
@ -1953,7 +1788,7 @@ class RWKV(MyModule):
|
||||
w["head.weight_ry"],
|
||||
)
|
||||
else:
|
||||
x = self.mm8_one(
|
||||
x = mm8_one(
|
||||
x,
|
||||
w["head.weight"],
|
||||
w["head.weight_mx"],
|
||||
|
BIN
backend-python/rwkv_pip/rwkv5.pyd
vendored
BIN
backend-python/rwkv_pip/rwkv5.pyd
vendored
Binary file not shown.
1
backend-python/rwkv_pip/utils.py
vendored
1
backend-python/rwkv_pip/utils.py
vendored
@ -81,6 +81,7 @@ class PIPELINE:
|
||||
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
|
||||
probs = F.softmax(logits.float(), dim=-1)
|
||||
top_k = int(top_k)
|
||||
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
|
||||
if probs.device.type in ["cpu", "privateuseone"]:
|
||||
probs = probs.cpu().numpy()
|
||||
sorted_ids = np.argsort(probs)
|
||||
|
BIN
backend-python/rwkv_pip/wkv_cuda.pyd
vendored
BIN
backend-python/rwkv_pip/wkv_cuda.pyd
vendored
Binary file not shown.
Loading…
Reference in New Issue
Block a user