upgrade to rwkv 0.8.20

commit 1f81a1e5a8 (parent 35e92d2aef)
@@ -3,6 +3,8 @@
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
 
 #define CUBLAS_CHECK(condition) \
   for (cublasStatus_t _cublas_check_status = (condition); \
@@ -18,26 +20,13 @@
     "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
     " at " + std::to_string(__LINE__));
 
-cublasHandle_t get_cublas_handle() {
-  static cublasHandle_t cublas_handle = []() {
-    cublasHandle_t handle = nullptr;
-    CUBLAS_CHECK(cublasCreate(&handle));
-#if CUDA_VERSION < 11000
-    CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-#else
-    CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
-#endif // CUDA_VERSION < 11000
-    return handle;
-  }();
-  return cublas_handle;
-}
-
 /*
 NOTE: blas gemm is column-major by default, but we need row-major output.
 The data of row-major, transposed matrix is exactly the same as the
 column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
 */
 void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   const auto cuda_data_type = CUDA_R_16F;
   const auto cuda_c_data_type =
       c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
@@ -55,7 +44,7 @@ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
   const int cublas_lda = m;
   const int cublas_ldb = k;
   const int cublas_ldc = m;
-  cublasHandle_t cublas_handle = get_cublas_handle();
+  cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
 
 #if CUDA_VERSION >= 11000
   cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
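The NOTE in this hunk is the key to calling a column-major GEMM from row-major PyTorch tensors: since (A·B)^T = B^T·A^T, handing cuBLAS the same buffers with the operands swapped makes it write C^T in column-major order, which is byte-for-byte the row-major C. A small NumPy sketch of that identity (illustrative only, not part of the commit):

```python
import numpy as np

# Row-major A (n x k) and B (k x m); we want row-major C = A @ B.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3)).astype(np.float32)
B = rng.standard_normal((3, 5)).astype(np.float32)

# A column-major library that is handed our row-major buffers sees A^T and B^T.
# Computing B^T @ A^T in that view yields C^T, whose column-major layout is
# exactly the row-major layout of C -- no extra transpose or copy is needed.
A_col = np.asfortranarray(A.T)   # same bytes as A, read as column-major (k x n)
B_col = np.asfortranarray(B.T)   # same bytes as B, read as column-major (m x k)
C_col = B_col @ A_col            # = (A @ B)^T in the column-major view

assert np.allclose(C_col.T, A @ B)
print("column-major trick reproduces the row-major product")
```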
backend-python/rwkv_pip/cuda/rwkv5_op.cpp (vendored, 4 changes)

@@ -1,5 +1,6 @@
 #include <torch/extension.h>
 #include "ATen/ATen.h"
+#include <c10/cuda/CUDAGuard.h>
 typedef at::BFloat16 bf16;
 typedef at::Half fp16;
 typedef float fp32;
@@ -9,12 +10,15 @@ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *
 void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
 
 void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
     cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
 }
 void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
     cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
 }
 void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
     cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
 }
 
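The new `OptionalCUDAGuard` lines matter on multi-GPU hosts: without them the kernels would launch on the current (usually first) device even when `state` lives elsewhere. Roughly the same effect, expressed at the Python level with `torch.cuda.device_of` (the body below is a stand-in for the real extension call, not code from this commit):

```python
import torch

def forward_on_right_device(state: torch.Tensor) -> torch.Tensor:
    # Python-level analogue of the C++ OptionalCUDAGuard: make the tensor's
    # device current for the duration of the call, so anything launched
    # inside runs on that device rather than on cuda:0.
    with torch.cuda.device_of(state):
        return state * 2.0  # stand-in for cuda_forward_*(...)

if torch.cuda.is_available():
    last = torch.device("cuda", torch.cuda.device_count() - 1)
    s = torch.ones(8, device=last)
    print(forward_on_right_device(s).device)
```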
backend-python/rwkv_pip/model.py (vendored, 523 changes)

@@ -171,10 +171,86 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
 else:
     os.environ["RWKV_CUDA_ON"] = "0"
 
-if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
+
+@MyStatic
+def torch_mm8_seq(x, w, mx, rx, my, ry):
+    return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
+
+
+@MyStatic
+def torch_mm8_one(x, w, mx, rx, my, ry):
+    return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
+
+
+if os.environ.get("RWKV_CUDA_ON") == "1":
 
     @MyStatic
-    def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
+    def mm8_seq(x, w, mx, rx, my, ry):
+        if w.device.type == "cuda" and x.dtype == torch.float16:
+            B, N, M = x.shape[0], w.shape[0], w.shape[1]
+            return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
+        else:
+            return torch_mm8_seq(x, w, mx, rx, my, ry)
+
+    @MyStatic
+    def mm8_one(x, w, mx, rx, my, ry):
+        if w.device.type == "cuda":
+            N, M = w.shape[0], w.shape[1]
+            return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
+        else:
+            return torch_mm8_one(x, w, mx, rx, my, ry)
+
+else:
+
+    @MyStatic
+    def mm8_seq(x, w, mx, rx, my, ry):
+        return torch_mm8_seq(x, w, mx, rx, my, ry)
+
+    @MyStatic
+    def mm8_one(x, w, mx, rx, my, ry):
+        return torch_mm8_one(x, w, mx, rx, my, ry)
+
+
+def mm8(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    mx: torch.Tensor,
+    rx: torch.Tensor,
+    my: torch.Tensor,
+    ry: torch.Tensor,
+):
+    if len(x.shape) == 1:
+        return mm8_one(x, w, mx, rx, my, ry)
+    return mm8_seq(x, w, mx, rx, my, ry)
+
+
+def matmul(
+    a,
+    b,
+    mx: Optional[torch.Tensor] = None,
+    rx: Optional[torch.Tensor] = None,
+    my: Optional[torch.Tensor] = None,
+    ry: Optional[torch.Tensor] = None,
+    output_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    if output_dtype is None:
+        output_dtype = a.dtype
+    if b.dtype in [torch.float16, torch.bfloat16, torch.float32]:
+        assert a.dtype == b.dtype
+        return matmul_float(a, b, output_dtype=output_dtype)
+    elif b.dtype == torch.uint8:
+        assert mx is not None
+        assert rx is not None
+        assert my is not None
+        assert ry is not None
+        return mm8(a, b, mx, rx, my, ry).to(output_dtype)
+    else:
+        raise ValueError("Unsupported dtype")
+
+
+if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
+
+    def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None):
         if output_dtype is None:
             output_dtype = a.dtype
         if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda":
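`torch_mm8_seq`/`torch_mm8_one` above are the fallback path for INT8 weights: `w` is stored as `uint8` plus per-column offset/scale (`mx`, `rx`) and per-row offset/scale (`my`, `ry`), and `(w + 0.5) * ry * rx + my + mx` dequantizes it on the fly before the matmul, while `matmul` dispatches between this path and `matmul_float` based on `b.dtype`. A self-contained sketch of a quantizer that is consistent with that formula (the loader's actual quantization code is not part of this diff, so treat `quantize_like_mm8` as an assumption):

```python
import torch

def quantize_like_mm8(w: torch.Tensor):
    # Hypothetical quantizer chosen to match the dequantization expression
    # (w_q + 0.5) * ry * rx + my + mx used by torch_mm8_seq/torch_mm8_one.
    my = torch.amin(w, dim=1, keepdim=True)      # per-row offset
    w = w - my
    mx = torch.amin(w, dim=0, keepdim=True)      # per-column offset
    w = w - mx
    rx = torch.amax(w, dim=0, keepdim=True)      # per-column range
    w = w / rx
    ry = torch.amax(w, dim=1, keepdim=True)      # per-row range
    w = w / ry                                   # values now in [0, 1]
    w_q = torch.clip(torch.floor(w * 256), 0, 255).to(torch.uint8)
    # Folding 1/256 into the two scales matches the (w_q + 0.5) convention.
    return w_q, mx, rx / 16, my, ry / 16

def torch_mm8(x, w_q, mx, rx, my, ry):
    return x @ ((w_q.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)

w = torch.randn(64, 32)
x = torch.randn(4, 64)
w_q, mx, rx, my, ry = quantize_like_mm8(w.clone())
err = (torch_mm8(x, w_q, mx, rx, my, ry) - x @ w).abs().max()
print(f"max abs error of the uint8 round trip: {err:.4f}")
```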
@@ -203,9 +279,7 @@ if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
 
 else:
 
-    def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
-        if output_dtype is None:
-            output_dtype = a.dtype
+    def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None):
         return (a @ b).to(output_dtype)
 
 
@@ -644,42 +718,6 @@ class RWKV(MyModule):
     def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u):
         return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u)
 
-    @MyFunction
-    def torch_mm8_seq(self, x, w, mx, rx, my, ry):
-        return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
-
-    @MyFunction
-    def torch_mm8_one(self, x, w, mx, rx, my, ry):
-        return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
-
-    if os.environ.get("RWKV_CUDA_ON") == "1":
-
-        @MyFunction
-        def mm8_seq(self, x, w, mx, rx, my, ry):
-            if w.device.type == "cuda" and x.dtype == torch.float16:
-                B, N, M = x.shape[0], w.shape[0], w.shape[1]
-                return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
-            else:
-                return self.torch_mm8_seq(x, w, mx, rx, my, ry)
-
-        @MyFunction
-        def mm8_one(self, x, w, mx, rx, my, ry):
-            if w.device.type == "cuda":
-                N, M = w.shape[0], w.shape[1]
-                return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
-            else:
-                return self.torch_mm8_one(x, w, mx, rx, my, ry)
-
-    else:
-
-        @MyFunction
-        def mm8_seq(self, x, w, mx, rx, my, ry):
-            return self.torch_mm8_seq(x, w, mx, rx, my, ry)
-
-        @MyFunction
-        def mm8_one(self, x, w, mx, rx, my, ry):
-            return self.torch_mm8_one(x, w, mx, rx, my, ry)
-
     ########################################################################################################
 
     @MyFunction
@@ -711,43 +749,9 @@ class RWKV(MyModule):
         kx = xx * k_mix + sx * (1 - k_mix)
         rx = xx * r_mix + sx * (1 - r_mix)
 
-        r = torch.sigmoid(gemm(rx, rw))
-        vx = torch.square(torch.relu(gemm(kx, kw)))
-        out = r * gemm(vx, vw)
-        return x + out, xx
-
-    @MyFunction
-    def ffn_one_i8(
-        self,
-        x,
-        sx,
-        ln_w,
-        ln_b,
-        k_mix,
-        r_mix,
-        kw,
-        vw,
-        rw,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        kx = xx * k_mix + sx * (1 - k_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-
-        r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry))
-        vx = torch.square(torch.relu(self.mm8_one(kx, kw, kmx, krx, kmy, kry)))
-        out = r * (self.mm8_one(vx, vw, vmx, vrx, vmy, vry))
+        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
+        vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)))
+        out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
         return x + out, xx
 
     ########################################################################################################
@@ -782,44 +786,9 @@ class RWKV(MyModule):
         kx = xx * k_mix + sx * (1 - k_mix)
         rx = xx * r_mix + sx * (1 - r_mix)
 
-        r = torch.sigmoid(gemm(rx, rw))
-        vx = torch.square(torch.relu(gemm(kx, kw)))
-        out = r * gemm(vx, vw)
-        return x + out, xx[-1, :]
-
-    @MyFunction
-    def ffn_seq_i8(
-        self,
-        x,
-        sx,
-        ln_w,
-        ln_b,
-        k_mix,
-        r_mix,
-        kw,
-        vw,
-        rw,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
-        kx = xx * k_mix + sx * (1 - k_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-
-        r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
-        vx = torch.square(torch.relu(self.mm8_seq(kx, kw, kmx, krx, kmy, kry)))
-        out = r * (self.mm8_seq(vx, vw, vmx, vrx, vmy, vry))
+        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
+        vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)))
+        out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
         return x + out, xx[-1, :]
 
     ########################################################################################################
@@ -865,9 +834,9 @@ class RWKV(MyModule):
         vx = xx * v_mix + sx * (1 - v_mix)
         rx = xx * r_mix + sx * (1 - r_mix)
 
-        r = torch.sigmoid(gemm(rx, rw))
-        k = gemm(kx, kw, output_dtype=torch.float32)
-        v = gemm(vx, vw, output_dtype=torch.float32)
+        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
+        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
+        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
 
         ww = t_first + k
         p = torch.maximum(pp, ww)
@@ -879,65 +848,7 @@ class RWKV(MyModule):
         e1 = torch.exp(ww - p)
         e2 = torch.exp(k - p)
 
-        out = gemm(r * wkv, ow)
-        return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p
-
-    @MyFunction
-    def att_one_i8(
-        self,
-        x,
-        sx,
-        aa,
-        bb,
-        pp,
-        ln_w,
-        ln_b,
-        k_mix,
-        v_mix,
-        r_mix,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        kx = xx * k_mix + sx * (1 - k_mix)
-        vx = xx * v_mix + sx * (1 - v_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-
-        r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry))
-        k = (self.mm8_one(kx, kw, kmx, krx, kmy, kry)).float()
-        v = (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)).float()
-
-        ww = t_first + k
-        p = torch.maximum(pp, ww)
-        e1 = torch.exp(pp - p)
-        e2 = torch.exp(ww - p)
-        wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
-        ww = t_decay + pp
-        p = torch.maximum(ww, k)
-        e1 = torch.exp(ww - p)
-        e2 = torch.exp(k - p)
-
-        out = self.mm8_one(r * wkv, ow, omx, orx, omy, ory)
+        out = matmul(r * wkv, ow, omx, orx, omy, ory)
         return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p
 
     ########################################################################################################
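Both the surviving float path and the deleted `att_one_i8` evaluate the WKV update through `e1 = exp(pp - p)` and `e2 = exp(ww - p)` with `p = torch.maximum(pp, ww)`: shifting by the running maximum keeps every exponent non-positive, so the ratio `(e1*aa + e2*v) / (e1*bb + e2)` stays finite even when the log-scale accumulators grow large. A standalone illustration (the numbers are made up):

```python
import torch

def stable_ratio(pp, ww, aa, bb, v):
    # Same max-shift used in att_one: exponents are taken relative to
    # p = max(pp, ww), so exp() never sees a large positive argument.
    p = torch.maximum(pp, ww)
    e1 = torch.exp(pp - p)
    e2 = torch.exp(ww - p)
    return (e1 * aa + e2 * v) / (e1 * bb + e2)

pp = torch.tensor([500.0])   # large log-scale accumulator
ww = torch.tensor([503.0])
aa = torch.tensor([1.5])
bb = torch.tensor([2.0])
v = torch.tensor([-0.3])

naive = (torch.exp(pp) * aa + torch.exp(ww) * v) / (torch.exp(pp) * bb + torch.exp(ww))
print("naive:", naive.item())    # nan: exp(500) overflows float32
print("stable:", stable_ratio(pp, ww, aa, bb, v).item())
```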
@@ -984,9 +895,9 @@ class RWKV(MyModule):
         vx = xx * v_mix + sx * (1 - v_mix)
         rx = xx * r_mix + sx * (1 - r_mix)
 
-        r = torch.sigmoid(gemm(rx, rw))
-        k = gemm(kx, kw, output_dtype=torch.float32)
-        v = gemm(vx, vw, output_dtype=torch.float32)
+        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
+        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
+        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
 
         T = x.shape[0]
         for t in range(T):
@@ -1004,72 +915,7 @@ class RWKV(MyModule):
             aa = e1 * aa + e2 * vv
             bb = e1 * bb + e2
             pp = p
-        out = gemm(r * sx, ow)
-        return x + out, xx[-1, :], aa, bb, pp
-
-    @MyFunction
-    def att_seq_i8(
-        self,
-        x,
-        sx,
-        aa,
-        bb,
-        pp,
-        ln_w,
-        ln_b,
-        k_mix,
-        v_mix,
-        r_mix,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
-        kx = xx * k_mix + sx * (1 - k_mix)
-        vx = xx * v_mix + sx * (1 - v_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-
-        r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
-        k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry).float()
-        v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry).float()
-
-        T = x.shape[0]
-        for t in range(T):
-            kk = k[t]
-            vv = v[t]
-            ww = t_first + kk
-            p = torch.maximum(pp, ww)
-            e1 = torch.exp(pp - p)
-            e2 = torch.exp(ww - p)
-            sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
-            ww = t_decay + pp
-            p = torch.maximum(ww, kk)
-            e1 = torch.exp(ww - p)
-            e2 = torch.exp(kk - p)
-            aa = e1 * aa + e2 * vv
-            bb = e1 * bb + e2
-            pp = p
-        out = self.mm8_seq(r * sx, ow, omx, orx, omy, ory)
+        out = matmul(r * sx, ow, omx, orx, omy, ory)
         return x + out, xx[-1, :], aa, bb, pp
 
     ########################################################################################################
@@ -1118,11 +964,11 @@ class RWKV(MyModule):
         H = t_decay.shape[0]
         S = x.shape[-1] // H
 
-        r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S)
-        k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1)
-        v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S)
+        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S)
+        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1)
+        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S)
 
-        a = gemm(k, v)
+        a = matmul(k, v)
         out = r @ (t_first * a + s)
         s = a + t_decay * s
 
@@ -1131,7 +977,7 @@ class RWKV(MyModule):
             out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
         ).squeeze(0)
         out = out.to(dtype=x.dtype)
-        out = gemm(out, ow)
+        out = matmul(out, ow, omx, orx, omy, ory)
 
         return x + out, xx, s
 
@@ -1194,14 +1040,22 @@ class RWKV(MyModule):
         w = w[:, :-T].reshape(-1, T, 2 * T - 1)
         w = w[:, :, T - 1 :].reshape(H, T, T)
 
-        r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
+        r = (
+            matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
+            .view(T, H, S)
+            .transpose(0, 1)
+        )
         k = (
-            gemm(kx, kw, output_dtype=torch.float32)
+            matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
             .view(T, H, S)
             .transpose(0, 1)
             .transpose(-2, -1)
         )
-        v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
+        v = (
+            matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
+            .view(T, H, S)
+            .transpose(0, 1)
+        )
 
         out = ((r @ k) * w) @ v + (r @ s) * wb
         s = ws * s + (k * wk) @ v
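The reshapes in these v5 attention paths split the flat `(T, C)` projections into `H` heads of size `S = C // H`: `.view(T, H, S).transpose(0, 1)` gives `(H, T, S)`, and the extra `.transpose(-2, -1)` on `k` gives `(H, S, T)`, so `r @ k` becomes a batched per-head `(T, T)` product. A shape-only sketch (sizes are illustrative):

```python
import torch

T, H, S = 5, 4, 8          # sequence length, heads, head size (illustrative)
C = H * S

rx = torch.randn(T, C)
rw = torch.randn(C, C)
kw = torch.randn(C, C)

r = (rx @ rw).view(T, H, S).transpose(0, 1)                     # (H, T, S)
k = (rx @ kw).view(T, H, S).transpose(0, 1).transpose(-2, -1)   # (H, S, T)

scores = r @ k             # batched per-head product -> (H, T, T)
print(scores.shape)        # torch.Size([4, 5, 5])
```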
@@ -1209,7 +1063,7 @@ class RWKV(MyModule):
         out = out.transpose(0, 1).contiguous().reshape(T, H * S)
         out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
         out = out.to(dtype=x.dtype)
-        out = gemm(out, ow)
+        out = matmul(out, ow, omx, orx, omy, ory)
 
         return x + out, xx[-1, :], s
 
@@ -1248,6 +1102,10 @@ class RWKV(MyModule):
         rrx,
         rmy,
         rry,
+        gmx,
+        grx,
+        gmy,
+        gry,
         omx,
         orx,
         omy,
@@ -1262,12 +1120,12 @@ class RWKV(MyModule):
         H = t_decay.shape[0]
         S = x.shape[-1] // H
 
-        r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S)
-        k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1)
-        v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S)
-        g = F.silu(gemm(gx, gw))
+        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S)
+        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1)
+        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S)
+        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
 
-        a = gemm(k, v)
+        a = matmul(k, v)
         out = r @ (t_first * a + s)
         s = a + t_decay * s
 
@@ -1276,7 +1134,7 @@ class RWKV(MyModule):
             out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
         ).squeeze(0)
         out = out.to(dtype=x.dtype) * g
-        out = gemm(out, ow)
+        out = matmul(out, ow, omx, orx, omy, ory)
 
         return x + out, xx, s
 
@@ -1313,6 +1171,10 @@ class RWKV(MyModule):
         rrx,
         rmy,
         rry,
+        gmx,
+        grx,
+        gmy,
+        gry,
         omx,
         orx,
         omy,
@@ -1342,15 +1204,23 @@ class RWKV(MyModule):
         w = w[:, :-T].reshape(-1, T, 2 * T - 1)
         w = w[:, :, T - 1 :].reshape(H, T, T)
 
-        r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
+        r = (
+            matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
+            .view(T, H, S)
+            .transpose(0, 1)
+        )
         k = (
-            gemm(kx, kw, output_dtype=torch.float32)
+            matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
             .view(T, H, S)
             .transpose(0, 1)
             .transpose(-2, -1)
         )
-        v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
-        g = F.silu(gemm(gx, gw))
+        v = (
+            matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
+            .view(T, H, S)
+            .transpose(0, 1)
+        )
+        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
 
         out = ((r @ k) * w) @ v + (r @ s) * wb
         s = ws * s + (k * wk) @ v
@@ -1358,7 +1228,7 @@ class RWKV(MyModule):
         out = out.transpose(0, 1).contiguous().reshape(T, H * S)
         out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
         out = out.to(dtype=x.dtype) * g
-        out = gemm(out, ow)
+        out = matmul(out, ow, omx, orx, omy, ory)
 
         return x + out, xx[-1, :], s
 
@@ -1397,6 +1267,10 @@ class RWKV(MyModule):
         rrx,
         rmy,
         rry,
+        gmx,
+        grx,
+        gmy,
+        gry,
         omx,
         orx,
         omy,
@@ -1413,29 +1287,37 @@ class RWKV(MyModule):
         S = x.shape[-1] // H
         T = x.shape[0]
 
-        r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
+        r = (
+            matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
+            .view(T, H, S)
+            .transpose(0, 1)
+        )
         k = (
-            gemm(kx, kw, output_dtype=torch.float32)
+            matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
             .view(T, H, S)
             .transpose(0, 1)
             .transpose(-2, -1)
         )
-        v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
-        g = F.silu(gemm(gx, gw))
+        v = (
+            matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
+            .view(T, H, S)
+            .transpose(0, 1)
+        )
+        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
 
         out = torch.empty((T, H, S), dtype=r.dtype, device=r.device)
         for t in range(T):
             rt = r[:, t : t + 1, :]
             kt = k[:, :, t : t + 1]
             vt = v[:, t : t + 1, :]
-            at = gemm(kt, vt)
+            at = matmul(kt, vt)
             out[t] = (rt @ (t_first * at + s)).squeeze(1)
             s = at + t_decay * s
 
         out = out.reshape(T, H * S)
         out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
         out = out.to(dtype=x.dtype) * g
-        out = gemm(out, ow)
+        out = matmul(out, ow, omx, orx, omy, ory)
 
         return x + out, xx[-1, :], s
 
@@ -1486,63 +1368,12 @@ class RWKV(MyModule):
         vx = xx * v_mix + sx * (1 - v_mix)
         rx = xx * r_mix + sx * (1 - r_mix)
 
-        r = torch.sigmoid(gemm(rx, rw))
-        k = gemm(kx, kw, output_dtype=torch.float32)
-        v = gemm(vx, vw, output_dtype=torch.float32)
-        y, aa, bb, pp = cuda_wkv(T, aa.shape[0], t_decay, t_first, k, v, aa, bb, pp)
-
-        out = gemm(r * y.to(x.dtype), ow)
-        return x + out, xx[-1, :], aa, bb, pp
-
-    @MyFunction
-    def cuda_att_seq_i8(
-        self,
-        x,
-        sx,
-        aa,
-        bb,
-        pp,
-        ln_w,
-        ln_b,
-        k_mix,
-        v_mix,
-        r_mix,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        T, C = x.shape
-        xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b)
-        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
-        kx = xx * k_mix + sx * (1 - k_mix)
-        vx = xx * v_mix + sx * (1 - v_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-
-        r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry))
-        k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry)
-        v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)
+        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
+        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
+        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
         y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp)
 
-        out = self.mm8_seq(r * y, ow, omx, orx, omy, ory)
+        out = matmul(r * y.to(x.dtype), ow, omx, orx, omy, ory)
         return x + out, xx[-1, :], aa, bb, pp
 
     # NOTE: decorate with @MyFunction causes JIT error
@@ -1578,6 +1409,10 @@ class RWKV(MyModule):
         rrx,
         rmy,
         rry,
+        gmx,
+        grx,
+        gmy,
+        gry,
         omx,
         orx,
         omy,
@@ -1594,10 +1429,10 @@ class RWKV(MyModule):
         N = x.shape[-1] // H
         T = x.shape[0]
 
-        r = gemm(rx, rw, output_dtype=torch.float32)
-        k = gemm(kx, kw, output_dtype=torch.float32)
-        v = gemm(vx, vw, output_dtype=torch.float32)
-        g = F.silu(gemm(gx, gw))
+        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
+        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
+        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
+        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
 
         out, s = self.RUN_RWKV_5(
             1,
@@ -1616,7 +1451,7 @@ class RWKV(MyModule):
         out = out.reshape(T, H * N)
         out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
         out = out.to(dtype=x.dtype) * g
-        out = gemm(out, ow)
+        out = matmul(out, ow, omx, orx, omy, ory)
 
         return x + out, xx[-1, :], s
 
@@ -1703,13 +1538,9 @@ class RWKV(MyModule):
                     "RWKV_CUDA_ON"
                 ] == "1" and "cuda" in str(dev)
                 if cuda_applicable:
-                    ATT = (
-                        self.cuda_att_seq
-                        if wtype != torch.uint8
-                        else self.cuda_att_seq_i8
-                    )
+                    ATT = self.cuda_att_seq
                 else:
-                    ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8
+                    ATT = self.att_seq
                 if self.version == 5:
                     ATT = self.att_seq_v5
                 elif self.version == 5.1:
@@ -1718,16 +1549,16 @@ class RWKV(MyModule):
                     ATT = self.att_seq_v5_2
                     if cuda_applicable:
                         ATT = self.cuda_att_seq_v5_2
-                FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8
+                FFN = self.ffn_seq
             else:
-                ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
+                ATT = self.att_one
                 if self.version == 5:
                     ATT = self.att_one_v5
                 elif self.version == 5.1:
                     ATT = self.att_one_v5_1
                 elif self.version == 5.2:
                     ATT = self.att_one_v5_1  # same as v5.1
-                FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8
+                FFN = self.ffn_one
 
             x = x.to(dtype=atype, device=dev)
 
@@ -1872,6 +1703,10 @@ class RWKV(MyModule):
                     rrx,
                     rmy,
                     rry,
+                    gmx,
+                    grx,
+                    gmy,
+                    gry,
                     omx,
                     orx,
                     omy,
@@ -1944,7 +1779,7 @@ class RWKV(MyModule):
                 x = x @ w["head.weight"]
             else:
                 if seq_mode and full_output:
-                    x = self.mm8_seq(
+                    x = mm8_seq(
                         x,
                         w["head.weight"],
                         w["head.weight_mx"],
@@ -1953,7 +1788,7 @@ class RWKV(MyModule):
                         w["head.weight_ry"],
                     )
                 else:
-                    x = self.mm8_one(
+                    x = mm8_one(
                         x,
                         w["head.weight"],
                         w["head.weight_mx"],
backend-python/rwkv_pip/rwkv5.pyd (vendored, binary file not shown)

backend-python/rwkv_pip/utils.py (vendored, 1 change)
@@ -81,6 +81,7 @@ class PIPELINE:
     def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
         probs = F.softmax(logits.float(), dim=-1)
         top_k = int(top_k)
+        # 'privateuseone' is the type of custom devices like `torch_directml.device()`
        if probs.device.type in ["cpu", "privateuseone"]:
             probs = probs.cpu().numpy()
             sorted_ids = np.argsort(probs)
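The added comment explains why the branch tests more than `"cpu"`: out-of-tree backends such as torch_directml report `device.type == "privateuseone"`, and for them sampling also falls back to NumPy. A minimal sketch of that dispatch (a simplified sampler, not the pipeline's exact code):

```python
import numpy as np
import torch
import torch.nn.functional as F

def sample_token(logits: torch.Tensor, top_p: float = 0.85) -> int:
    probs = F.softmax(logits.float(), dim=-1)
    # CPU and custom backends (e.g. torch_directml -> 'privateuseone')
    # take the NumPy path; CUDA stays on-device in the real pipeline.
    if probs.device.type in ["cpu", "privateuseone"]:
        p = probs.cpu().numpy()
        sorted_p = np.sort(p)[::-1]
        cutoff = sorted_p[np.argmax(np.cumsum(sorted_p) >= top_p)]
        p[p < cutoff] = 0.0
        p = p / p.sum()
        return int(np.random.choice(len(p), p=p))
    return int(torch.multinomial(probs, num_samples=1).item())

print(sample_token(torch.randn(16)))
```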
backend-python/rwkv_pip/wkv_cuda.pyd (vendored, binary file not shown)