upgrade cuda-beta

josc146 2023-09-15 16:30:11 +08:00
parent c4042bbfd8
commit df969fcfc6
6 changed files with 601 additions and 90 deletions

View File

@ -88,7 +88,7 @@ struct Mix {
using torch::Tensor;
-void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
+void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
Tensor v_mix, Tensor r_mix, Tensor kw,
@ -105,9 +105,9 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
data_ptr<half>(vx), data_ptr<half>(rx)},
x.numel());
-gemm_fp16_cublas(kx, kw, k);
+gemm_fp16_cublas_tensor(kx, kw, k);
-gemm_fp16_cublas(vx, vw, v);
+gemm_fp16_cublas_tensor(vx, vw, v);
-gemm_fp16_cublas(rx, rw, r);
+gemm_fp16_cublas_tensor(rx, rw, r);
at::sigmoid_(r);
element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
@ -118,7 +118,7 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
data_ptr<half>(r)},
x.numel());
-gemm_fp16_cublas(r, ow, x_plus_out);
+gemm_fp16_cublas_tensor(r, ow, x_plus_out);
x_plus_out += x;
return xx;
}

View File

@@ -0,0 +1,109 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "element_wise.h"
#include "util.h"
// Equivalent Python code:
// s1 = t_first * a + s
// s2 = a + t_decay * s
struct Fused1 {
const float *t_first;
const float *t_decay;
const float *a;
const float *s;
const int32_t inner_size;
/* out */ float *s1;
/* out */ float *s2;
__device__ void operator()(int i) const {
const int j = i / inner_size;
s1[i] = t_first[j] * a[i] + s[i];
s2[i] = a[i] + t_decay[j] * s[i];
}
};
/*
Equivalent Python code:
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
*/
struct Mix {
const half *xx;
const half *sx;
const half *k_mix;
const half *v_mix;
const half *r_mix;
/* out */ half *kx;
/* out */ half *vx;
/* out */ half *rx;
__device__ void operator()(int i) const {
half xx_ = xx[i];
half sx_ = sx[i];
half k_mix_ = k_mix[i];
half v_mix_ = v_mix[i];
half r_mix_ = r_mix[i];
kx[i] = __hadd(__hmul(xx_, k_mix_),
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
vx[i] = __hadd(__hmul(xx_, v_mix_),
__hmul(sx_, __hsub(__float2half(1), v_mix_)));
rx[i] = __hadd(__hmul(xx_, r_mix_),
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
}
};
using torch::Tensor;
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix,
Tensor r_mix, Tensor kw,
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx,
Tensor rw,
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
/* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v,
/* imm */ Tensor r, /* imm */ Tensor s1,
/* out */ Tensor x_plus_out, /* out */ Tensor s2) {
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
data_ptr<half>(k_mix), data_ptr<half>(v_mix),
data_ptr<half>(r_mix), data_ptr<half>(kx),
data_ptr<half>(vx), data_ptr<half>(rx)},
x.numel());
int H = t_decay.size(0);
int S = x.size(-1) / H;
gemm_fp16_cublas_tensor(rx, rw, r);
r = at::reshape(r, {H, 1, S});
gemm_fp16_cublas_tensor(kx, kw, k);
k = at::reshape(k, {H, S, 1});
gemm_fp16_cublas_tensor(vx, vw, v);
v = at::reshape(v, {H, 1, S});
{
Tensor a = at::matmul(k, v);
// s1 = t_first * a + s
// s2 = a + t_decay * s
element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay),
data_ptr<float>(a), data_ptr<float>(s),
static_cast<int32_t>(a.size(1) * a.size(2)),
data_ptr<float>(s1), data_ptr<float>(s2)},
a.numel());
}
Tensor out = at::matmul(r, s1);
out = at::flatten(out);
out = at::squeeze(at::group_norm(at::unsqueeze(out, 0), H, lx_w, lx_b), 0);
out = at::_cast_Half(out);
gemm_fp16_cublas_tensor(out, ow, x_plus_out);
x_plus_out += x;
return xx;
}
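For orientation, here is a minimal PyTorch sketch (an editorial illustration, not part of the commit) of the per-head math that att_one_v5 fuses, following the kernel's own "Equivalent Python code" comments; shapes mirror the reshapes above and all values are random placeholders.

import torch

# Editorial sketch: the math att_one_v5 fuses, per head, with random inputs.
# Shapes follow the kernel: r (H, 1, S), k (H, S, 1), v (H, 1, S),
# s (H, S, S) state, t_first / t_decay broadcast per head.
H, S = 4, 8
r = torch.randn(H, 1, S)
k = torch.randn(H, S, 1)
v = torch.randn(H, 1, S)
s = torch.randn(H, S, S)
t_first = torch.randn(H, 1, 1)
t_decay = torch.rand(H, 1, 1)

a = k @ v                    # outer product per head, (H, S, S)
s1 = t_first * a + s         # what Fused1 writes into s1
s2 = a + t_decay * s         # what Fused1 writes into s2 (the carried-over state)
out = (r @ s1).flatten()     # then group_norm, the half cast and the output GEMM follow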

View File

@ -8,7 +8,6 @@
using torch::Tensor;
-void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
void gemm_fp16_cublas(const void *a, const void *b, void *c, int m,
int n, int k, bool output_fp32);

View File

@ -70,11 +70,59 @@ void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m,
cuda_c_data_type, cublas_ldc, compute_type, algo));
}
-void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
-// comptiable with rwkv one mode, 1-D tensor * 2-D tensor
-const int m = a.dense_dim() == 1 ? 1 : a.size(0);
-const int n = b.size(1);
-const int k = b.size(0);
-gemm_fp16_cublas(a.data_ptr(), b.data_ptr(), c.data_ptr(), m, n, k,
-c.dtype() == torch::kFloat32);
/*
NOTE: blas gemm is column-major by default, but we need row-major output.
The data of row-major, transposed matrix is exactly the same as the
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
*/
void gemm_fp16_cublas_tensor(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
if (a.sizes().size() == 1) {
assert(b.sizes().size() == 2);
a = at::unsqueeze(a, 0);
}
const auto cuda_data_type = CUDA_R_16F;
const auto cuda_c_data_type =
c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
const auto compute_type = CUDA_R_32F;
const float sp_alpha = 1.f;
// swap a and b, and use CUBLAS_OP_N. see the notes above
std::swap(a, b);
const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
// m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
// negative axis is used because of the existence of batch matmul.
const int m = a.size(-1);
const int k = a.size(-2);
const int n = b.size(-2);
const int cublas_lda = m;
const int cublas_ldb = k;
const int cublas_ldc = m;
cublasHandle_t cublas_handle = get_cublas_handle();
#if CUDA_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#endif
const float sp_beta = 0.f;
if (a.sizes().size() == 2 && b.sizes().size() == 2) {
CUBLAS_CHECK(cublasGemmEx(
cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
compute_type, algo));
} else {
// batch matmul
assert(a.sizes().size() == 3 && b.sizes().size() == 3);
const long long int cublas_stride_a = m * k;
const long long int cublas_stride_b = k * n;
const long long int cublas_stride_c = m * n;
CUBLAS_CHECK(cublasGemmStridedBatchedEx(
cublas_handle, cublas_trans_a, cublas_trans_b, m,
n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
&sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
a.size(0), compute_type, algo));
}
}
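The NOTE in the new gemm_fp16_cublas_tensor is the whole trick behind std::swap(a, b) with CUBLAS_OP_N. A small editorial sketch of the identity it relies on (not part of the commit):

import torch

# cuBLAS sees column-major matrices, while torch tensors are row-major.
# A row-major buffer reinterpreted as column-major is the transpose, so
# asking cuBLAS for B^T * A^T with no transpose flags leaves (A * B)^T in a
# column-major buffer, which is exactly A * B laid out row-major.
m, k, n = 2, 3, 4
A = torch.arange(m * k, dtype=torch.float32).reshape(m, k)
B = torch.arange(k * n, dtype=torch.float32).reshape(k, n)
C_T = B.T @ A.T                      # what the swapped, non-transposed GEMM produces
assert torch.allclose(C_T.T, A @ B)  # its transpose is C = A @ B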

View File

@ -118,7 +118,9 @@ void mm8_one(int64_t N, int64_t M,
using torch::Tensor;
-void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
+#ifndef DISABLE_CUBLAS_GEMM
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
#endif
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
Tensor v_mix, Tensor r_mix, Tensor kw,
@ -134,6 +136,16 @@ Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out);
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix,
Tensor r_mix, Tensor kw,
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx,
Tensor rw,
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
/* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v,
/* imm */ Tensor r, /* imm */ Tensor s1,
/* out */ Tensor x_plus_out, /* out */ Tensor s2);
Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
/* imm */ Tensor buf,
@ -148,8 +160,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("wkv_forward", &wkv_forward, "wkv forward"); m.def("wkv_forward", &wkv_forward, "wkv forward");
m.def("mm8_seq", &mm8_seq, "mm8 seq"); m.def("mm8_seq", &mm8_seq, "mm8 seq");
m.def("mm8_one", &mm8_one, "mm8 one"); m.def("mm8_one", &mm8_one, "mm8 one");
m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas"); m.def("gemm_fp16_cublas", &gemm_fp16_cublas_tensor, "gemv fp16 cublas");
m.def("att_one", &att_one, "att one"); m.def("att_one", &att_one, "att one");
m.def("att_one_v5", &att_one_v5, "att one v5");
m.def("att_seq", &att_seq, "att seq"); m.def("att_seq", &att_seq, "att seq");
m.def("ffn_seq", &ffn_seq, "ffn seq"); m.def("ffn_seq", &ffn_seq, "ffn seq");
m.def("ffn_one", &ffn_one, "ffn one"); m.def("ffn_one", &ffn_one, "ffn one");
@ -159,8 +172,9 @@ TORCH_LIBRARY(rwkv, m) {
m.def("wkv_forward", wkv_forward); m.def("wkv_forward", wkv_forward);
m.def("mm8_seq", mm8_seq); m.def("mm8_seq", mm8_seq);
m.def("mm8_one", mm8_one); m.def("mm8_one", mm8_one);
m.def("gemm_fp16_cublas", gemm_fp16_cublas); m.def("gemm_fp16_cublas", gemm_fp16_cublas_tensor);
m.def("att_one", att_one); m.def("att_one", att_one);
m.def("att_one_v5", &att_one_v5);
m.def("att_seq", att_seq); m.def("att_seq", att_seq);
m.def("ffn_seq", ffn_seq); m.def("ffn_seq", ffn_seq);
m.def("ffn_one", ffn_one); m.def("ffn_one", ffn_one);

View File

@ -3,7 +3,7 @@
########################################################################################################
from typing import Optional
-import types, gc, os, time, re
+import types, gc, os, time, re, platform
import torch
from torch.nn import functional as F
@ -91,6 +91,7 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
f"{current_path}/cuda/att_one.cu", f"{current_path}/cuda/att_one.cu",
f"{current_path}/cuda/att_seq.cu", f"{current_path}/cuda/att_seq.cu",
f"{current_path}/cuda/ffn.cu", f"{current_path}/cuda/ffn.cu",
f"{current_path}/cuda/att_one_v5.cu",
],
verbose=True,
extra_cuda_cflags=[
@ -149,26 +150,40 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y)
return y.to(dtype=x.dtype)
else:
os.environ["RWKV_CUDA_ON"] = "0"
if os.environ.get("RWKV_CUDA_ON") == "1":
@MyStatic
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
if output_dtype is None:
output_dtype = a.dtype
if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda":
-assert len(b.shape) == 2
if len(a.shape) == 1:
assert len(b.shape) == 2
c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device)
a = a.unsqueeze(0)
else:
assert len(a.shape) == len(b.shape)
assert len(a.shape) == 2 or len(a.shape) == 3
# torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit
if len(a.shape) == 2:
c = torch.empty(
(a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device
)
else:
c = torch.empty(
(a.shape[0], a.shape[1], b.shape[-1]),
dtype=output_dtype,
device=a.device,
)
torch.ops.rwkv.gemm_fp16_cublas(a, b, c)
return c
else:
return (a @ b).to(output_dtype)
else:
-os.environ["RWKV_CUDA_ON"] = "0"
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
if output_dtype is None:
@ -217,7 +232,7 @@ class RWKV(MyModule):
) # load model to CPU first
# it is supported to load a pure meta-tensor state dict (e.g. for quick testing)
for k, v in self.w.items():
-if v.is_meta:
+if isinstance(v, torch.Tensor) and v.is_meta:
# torch.zeros_like(v, device='cpu') doesn't produce an all-zero tensor
# if v is a meta tensor
self.w[k] = torch.zeros(v.shape, dtype=v.dtype, device="cpu")
@ -247,9 +262,14 @@ class RWKV(MyModule):
args.n_embd = w["emb.weight"].shape[1] args.n_embd = w["emb.weight"].shape[1]
args.n_layer = 0 args.n_layer = 0
keys = list(w.keys()) keys = list(w.keys())
self.version = 4
for x in keys: for x in keys:
layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0 layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
args.n_layer = max(args.n_layer, layer_id + 1) args.n_layer = max(args.n_layer, layer_id + 1)
if "ln_x" in x:
self.version = 5
if self.version == 5 and "att.time_decay" in x:
args.n_head = w[x].shape[0]
####################### Compute strategy ####################### Compute strategy
@ -352,6 +372,20 @@ class RWKV(MyModule):
del w["blocks.0.ln0.bias"] del w["blocks.0.ln0.bias"]
print_need_newline = False print_need_newline = False
REAL_TIME_FIRST = False
for x in list(w.keys()):
if ".time_faaaa" in x:
REAL_TIME_FIRST = True
if REAL_TIME_FIRST:
w = {
k.replace(".time_faaaa", ".time_first")
if ".time_faaaa" in k
else k: v
for k, v in w.items()
}
self.w = w
keys = list(w.keys())
for x in keys:
w[x].requires_grad = False
@ -382,8 +416,19 @@ class RWKV(MyModule):
w[x] = w[x].t()
if ".time_decay" in x: # need fp32 for this
if self.version == 4:
w[x] = -torch.exp(w[x].float())
elif self.version == 5:
w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1, 1, 1)
elif ".time_first" in x: # need fp32 for this
if self.version == 4:
w[x] = w[x].float()
elif self.version == 5:
if REAL_TIME_FIRST:
w[x] = w[x].float().reshape(-1, 1, 1)
else:
w[x] = torch.exp(w[x].float()).reshape(-1, 1, 1)
elif ".ln_x" in x: # need fp32 for group_norm
w[x] = w[x].float()
else:
if (len(w[x].shape) == 2) and ("emb" not in x):
@ -931,6 +976,147 @@ class RWKV(MyModule):
########################################################################################################
@MyFunction
def att_one_v5(
self,
x,
sx,
s,
ln_w,
ln_b,
lx_w,
lx_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
H = t_decay.shape[0]
S = x.shape[-1] // H
r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S)
k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1)
v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S)
a = gemm(k, v)
out = r @ (t_first * a + s)
s = a + t_decay * s
out = out.flatten()
out = F.group_norm(
out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
).squeeze(0)
out = out.to(dtype=x.dtype)
out = gemm(out, ow)
return x + out, xx, s
@MyFunction
def att_seq_v5(
self,
x,
sx,
s,
ln_w,
ln_b,
lx_w,
lx_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
H = t_decay.shape[0]
S = x.shape[-1] // H
T = x.shape[0]
w = t_decay.reshape(-1, 1)
u = t_first.reshape(-1, 1)
ws = w.pow(T).reshape(H, 1, 1)
ind = torch.arange(T - 1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1)
w = w.repeat(1, T).pow(ind)
wk = w.reshape(H, 1, T)
wb = wk.transpose(-2, -1).flip(1)
w = torch.cat([w[:, 1:], u], dim=1)
w = F.pad(w, (0, T))
w = torch.tile(w, [T])
w = w[:, :-T].reshape(-1, T, 2 * T - 1)
w = w[:, :, T - 1 :].reshape(H, T, T)
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
k = (
gemm(kx, kw, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
.transpose(-2, -1)
)
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
out = ((r @ k) * w) @ v + (r @ s) * wb
s = ws * s + (k * wk) @ v
out = out.transpose(0, 1).contiguous().reshape(T, H * S)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype)
out = gemm(out, ow)
return x + out, xx[-1, :], s
########################################################################################################
if os.environ["RWKV_CUDA_ON"] == "1": if os.environ["RWKV_CUDA_ON"] == "1":
@MyFunction @MyFunction
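The decay tensors built at the top of att_seq_v5 above are the densest part of the new path. The following editorial sketch (not part of the commit) instantiates that construction for one head and T = 3 with an arbitrary decay d and bonus u, so the resulting layout is visible.

import torch
import torch.nn.functional as F

# Editorial sketch: the decay tensors from att_seq_v5 for H = 1, T = 3,
# with placeholder decay d and bonus u. The construction lines are the
# same as in att_seq_v5 above.
H, T = 1, 3
d, bonus = 0.5, 2.0
t_decay = torch.tensor([d])
t_first = torch.tensor([bonus])

w = t_decay.reshape(-1, 1)
u = t_first.reshape(-1, 1)
ws = w.pow(T).reshape(H, 1, 1)                 # d**T, decays the carried-in state s
ind = torch.arange(T - 1, -1, -1).unsqueeze(0).repeat(H, 1)
w = w.repeat(1, T).pow(ind)                    # [d**2, d, 1]
wk = w.reshape(H, 1, T)                        # weights on k when folding the chunk into s
wb = wk.transpose(-2, -1).flip(1)              # [1, d, d**2]^T, weights on r @ s per step
w = torch.cat([w[:, 1:], u], dim=1)
w = F.pad(w, (0, T))
w = torch.tile(w, [T])
w = w[:, :-T].reshape(-1, T, 2 * T - 1)
w = w[:, :, T - 1:].reshape(H, T, T)
# w[0] is the causal mixing matrix applied to (r @ k):
#   [[u, 0, 0],
#    [1, u, 0],
#    [d, 1, u]]
# row t weights position t' with u on the diagonal and d**(t - t' - 1) below it.
print(w[0], wk[0], wb[0], ws[0])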
@ -1140,7 +1326,7 @@ class RWKV(MyModule):
xx = torch.ops.rwkv.ffn_seq(
x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, buf, x_plus_out
)
-return x_plus_out, xx[-1:]
+return x_plus_out, xx[-1, :]
@MyFunction
def cuda_att_one_fp16(
@ -1220,6 +1406,86 @@ class RWKV(MyModule):
)
return x_plus_out_t, xx, t1_t, t2_t, p_t
@MyFunction
def cuda_att_one_v5_fp16(
self,
x,
sx,
s,
ln_w,
ln_b,
lx_w,
lx_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
kx = torch.empty_like(x)
vx = torch.empty_like(x)
rx = torch.empty_like(x)
H = t_decay.shape[0]
S = x.shape[-1] // H
r = torch.empty((H * S,), dtype=torch.float32, device=x.device)
k = torch.empty((H * S,), dtype=torch.float32, device=x.device)
v = torch.empty((H * S,), dtype=torch.float32, device=x.device)
s1 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
s2 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
x_plus_out = torch.empty_like(x)
xx = torch.ops.rwkv.att_one_v5(
x,
sx,
s,
ln_w,
ln_b,
lx_w,
lx_b,
k_mix,
v_mix,
r_mix,
kw,
kx,
vw,
vx,
rw,
rx,
ow,
t_first,
k,
t_decay,
v,
r,
s1,
x_plus_out,
s2,
)
return x_plus_out, xx, s2
@MyFunction
def cuda_ffn_one_fp16(
self,
@ -1265,6 +1531,7 @@ class RWKV(MyModule):
args = self.args
if state == None:
if self.version == 4:
state = [None] * args.n_layer * 5
for i in range(
args.n_layer
@ -1276,10 +1543,16 @@ class RWKV(MyModule):
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
state[i * 5 + 1] = torch.zeros(
-args.n_embd, dtype=torch.float, requires_grad=False, device=dev
+args.n_embd,
dtype=torch.float,
requires_grad=False,
device=dev,
).contiguous()
state[i * 5 + 2] = torch.zeros(
-args.n_embd, dtype=torch.float, requires_grad=False, device=dev
+args.n_embd,
dtype=torch.float,
requires_grad=False,
device=dev,
).contiguous()
state[i * 5 + 3] = (
torch.zeros(
@ -1293,6 +1566,28 @@ class RWKV(MyModule):
state[i * 5 + 4] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
elif self.version == 5:
state = [None] * args.n_layer * 3
for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx
dd = self.strategy[i]
dev = dd.device
atype = dd.atype
state[i * 3 + 0] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
state[i * 3 + 1] = torch.zeros(
(
args.n_head,
args.n_embd // args.n_head,
args.n_embd // args.n_head,
),
dtype=torch.float,
requires_grad=False,
device=dev,
).contiguous()
state[i * 3 + 2] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
seq_mode = len(tokens) > 1
@ -1317,9 +1612,13 @@ class RWKV(MyModule):
ATT = self.cuda_att_seq_i8
else:
ATT = self.cuda_att_seq_naive
if self.version == 5:
ATT = self.att_seq_v5
else:
ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8
if self.version == 5:
ATT = self.att_one_v5
if (
"cuda" in str(dev)
and os.environ["RWKV_CUDA_ON"] == "1"
@ -1327,6 +1626,8 @@ class RWKV(MyModule):
):
ATT = self.cuda_att_one_fp16
FFN = self.cuda_ffn_one_fp16
if self.version == 5:
ATT = self.cuda_att_one_v5_fp16
x = x.to(dtype=atype, device=dev)
@ -1355,6 +1656,7 @@ class RWKV(MyModule):
orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x
omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x
ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x
if self.version == 4:
( (
x, x,
state[i * 5 + 0], state[i * 5 + 0],
@ -1395,6 +1697,41 @@ class RWKV(MyModule):
omy,
ory,
)
elif self.version == 5:
x, state[i * 3 + 0], state[i * 3 + 1] = ATT(
x,
state[i * 3 + 0],
state[i * 3 + 1],
w[f"{bbb}ln1.weight"],
w[f"{bbb}ln1.bias"],
w[f"{att}ln_x.weight"],
w[f"{att}ln_x.bias"],
w[f"{att}time_mix_k"],
w[f"{att}time_mix_v"],
w[f"{att}time_mix_r"],
w[f"{att}time_decay"],
w[f"{att}time_first"],
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
)
if dd.stream:
del kw, vw, rw, ow
@ -1417,9 +1754,13 @@ class RWKV(MyModule):
rrx = w[f"{ffn}receptance.weight_rx"] if wtype == torch.uint8 else x rrx = w[f"{ffn}receptance.weight_rx"] if wtype == torch.uint8 else x
rmy = w[f"{ffn}receptance.weight_my"] if wtype == torch.uint8 else x rmy = w[f"{ffn}receptance.weight_my"] if wtype == torch.uint8 else x
rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x
x, state[i * 5 + 4] = FFN( if self.version == 4:
offset = i * 5 + 4
elif self.version == 5:
offset = i * 3 + 2
x, state[offset] = FFN(
x, x,
state[i * 5 + 4], state[offset],
w[f"{bbb}ln2.weight"], w[f"{bbb}ln2.weight"],
w[f"{bbb}ln2.bias"], w[f"{bbb}ln2.bias"],
w[f"{ffn}time_mix_k"], w[f"{ffn}time_mix_k"],