upgrade cuda-beta
This commit is contained in:
parent
c4042bbfd8
commit
df969fcfc6
10
backend-python/rwkv_pip/beta/cuda/att_one.cu
vendored
10
backend-python/rwkv_pip/beta/cuda/att_one.cu
vendored
@ -88,7 +88,7 @@ struct Mix {
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
||||
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
|
||||
|
||||
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||
Tensor v_mix, Tensor r_mix, Tensor kw,
|
||||
@ -105,9 +105,9 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||
data_ptr<half>(vx), data_ptr<half>(rx)},
|
||||
x.numel());
|
||||
|
||||
gemm_fp16_cublas(kx, kw, k);
|
||||
gemm_fp16_cublas(vx, vw, v);
|
||||
gemm_fp16_cublas(rx, rw, r);
|
||||
gemm_fp16_cublas_tensor(kx, kw, k);
|
||||
gemm_fp16_cublas_tensor(vx, vw, v);
|
||||
gemm_fp16_cublas_tensor(rx, rw, r);
|
||||
at::sigmoid_(r);
|
||||
|
||||
element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
|
||||
@ -118,7 +118,7 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||
data_ptr<half>(r)},
|
||||
x.numel());
|
||||
|
||||
gemm_fp16_cublas(r, ow, x_plus_out);
|
||||
gemm_fp16_cublas_tensor(r, ow, x_plus_out);
|
||||
x_plus_out += x;
|
||||
return xx;
|
||||
}
|
||||
|
109
backend-python/rwkv_pip/beta/cuda/att_one_v5.cu
vendored
Normal file
109
backend-python/rwkv_pip/beta/cuda/att_one_v5.cu
vendored
Normal file
@ -0,0 +1,109 @@
|
||||
#include "ATen/ATen.h"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "element_wise.h"
|
||||
#include "util.h"
|
||||
|
||||
// Equivalent Python code:
|
||||
// s1 = t_first * a + s
|
||||
// s2 = a + t_decay * s
|
||||
struct Fused1 {
|
||||
const float *t_first;
|
||||
const float *t_decay;
|
||||
const float *a;
|
||||
const float *s;
|
||||
const int32_t inner_size;
|
||||
/* out */ float *s1;
|
||||
/* out */ float *s2;
|
||||
|
||||
__device__ void operator()(int i) const {
|
||||
const int j = i / inner_size;
|
||||
s1[i] = t_first[j] * a[i] + s[i];
|
||||
s2[i] = a[i] + t_decay[j] * s[i];
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
Equivalent Python code:
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
*/
|
||||
|
||||
struct Mix {
|
||||
const half *xx;
|
||||
const half *sx;
|
||||
const half *k_mix;
|
||||
const half *v_mix;
|
||||
const half *r_mix;
|
||||
/* out */ half *kx;
|
||||
/* out */ half *vx;
|
||||
/* out */ half *rx;
|
||||
|
||||
__device__ void operator()(int i) const {
|
||||
half xx_ = xx[i];
|
||||
half sx_ = sx[i];
|
||||
half k_mix_ = k_mix[i];
|
||||
half v_mix_ = v_mix[i];
|
||||
half r_mix_ = r_mix[i];
|
||||
kx[i] = __hadd(__hmul(xx_, k_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
|
||||
vx[i] = __hadd(__hmul(xx_, v_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), v_mix_)));
|
||||
rx[i] = __hadd(__hmul(xx_, r_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
|
||||
}
|
||||
};
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
|
||||
|
||||
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
|
||||
Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix,
|
||||
Tensor r_mix, Tensor kw,
|
||||
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx,
|
||||
Tensor rw,
|
||||
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
|
||||
/* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v,
|
||||
/* imm */ Tensor r, /* imm */ Tensor s1,
|
||||
/* out */ Tensor x_plus_out, /* out */ Tensor s2) {
|
||||
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
|
||||
element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
|
||||
data_ptr<half>(k_mix), data_ptr<half>(v_mix),
|
||||
data_ptr<half>(r_mix), data_ptr<half>(kx),
|
||||
data_ptr<half>(vx), data_ptr<half>(rx)},
|
||||
x.numel());
|
||||
|
||||
int H = t_decay.size(0);
|
||||
int S = x.size(-1) / H;
|
||||
gemm_fp16_cublas_tensor(rx, rw, r);
|
||||
r = at::reshape(r, {H, 1, S});
|
||||
gemm_fp16_cublas_tensor(kx, kw, k);
|
||||
k = at::reshape(k, {H, S, 1});
|
||||
gemm_fp16_cublas_tensor(vx, vw, v);
|
||||
v = at::reshape(v, {H, 1, S});
|
||||
|
||||
{
|
||||
Tensor a = at::matmul(k, v);
|
||||
|
||||
// s1 = t_first * a + s
|
||||
// s2 = a + t_decay * s
|
||||
element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay),
|
||||
data_ptr<float>(a), data_ptr<float>(s),
|
||||
static_cast<int32_t>(a.size(1) * a.size(2)),
|
||||
data_ptr<float>(s1), data_ptr<float>(s2)},
|
||||
a.numel());
|
||||
}
|
||||
|
||||
Tensor out = at::matmul(r, s1);
|
||||
out = at::flatten(out);
|
||||
out = at::squeeze(at::group_norm(at::unsqueeze(out, 0), H, lx_w, lx_b), 0);
|
||||
out = at::_cast_Half(out);
|
||||
|
||||
gemm_fp16_cublas_tensor(out, ow, x_plus_out);
|
||||
x_plus_out += x;
|
||||
return xx;
|
||||
}
|
1
backend-python/rwkv_pip/beta/cuda/att_seq.cu
vendored
1
backend-python/rwkv_pip/beta/cuda/att_seq.cu
vendored
@ -8,7 +8,6 @@
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
||||
void gemm_fp16_cublas(const void *a, const void *b, void *c, int m,
|
||||
int n, int k, bool output_fp32);
|
||||
|
||||
|
@ -70,11 +70,59 @@ void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m,
|
||||
cuda_c_data_type, cublas_ldc, compute_type, algo));
|
||||
}
|
||||
|
||||
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
|
||||
// comptiable with rwkv one mode, 1-D tensor * 2-D tensor
|
||||
const int m = a.dense_dim() == 1 ? 1 : a.size(0);
|
||||
const int n = b.size(1);
|
||||
const int k = b.size(0);
|
||||
gemm_fp16_cublas(a.data_ptr(), b.data_ptr(), c.data_ptr(), m, n, k,
|
||||
c.dtype() == torch::kFloat32);
|
||||
/*
|
||||
NOTE: blas gemm is column-major by default, but we need row-major output.
|
||||
The data of row-major, transposed matrix is exactly the same as the
|
||||
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
|
||||
*/
|
||||
void gemm_fp16_cublas_tensor(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
|
||||
if (a.sizes().size() == 1) {
|
||||
assert(b.sizes().size() == 2);
|
||||
a = at::unsqueeze(a, 0);
|
||||
}
|
||||
const auto cuda_data_type = CUDA_R_16F;
|
||||
const auto cuda_c_data_type =
|
||||
c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
|
||||
const auto compute_type = CUDA_R_32F;
|
||||
const float sp_alpha = 1.f;
|
||||
// swap a and b, and use CUBLAS_OP_N. see the notes above
|
||||
std::swap(a, b);
|
||||
const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
|
||||
const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
|
||||
// m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
|
||||
// negative axis is used because of the existence of batch matmul.
|
||||
const int m = a.size(-1);
|
||||
const int k = a.size(-2);
|
||||
const int n = b.size(-2);
|
||||
const int cublas_lda = m;
|
||||
const int cublas_ldb = k;
|
||||
const int cublas_ldc = m;
|
||||
cublasHandle_t cublas_handle = get_cublas_handle();
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
#else
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
|
||||
#endif
|
||||
const float sp_beta = 0.f;
|
||||
if (a.sizes().size() == 2 && b.sizes().size() == 2) {
|
||||
CUBLAS_CHECK(cublasGemmEx(
|
||||
cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
|
||||
a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
|
||||
cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
|
||||
compute_type, algo));
|
||||
} else {
|
||||
// batch matmul
|
||||
assert(a.sizes().size() == 3 && b.sizes().size() == 3);
|
||||
|
||||
const long long int cublas_stride_a = m * k;
|
||||
const long long int cublas_stride_b = k * n;
|
||||
const long long int cublas_stride_c = m * n;
|
||||
CUBLAS_CHECK(cublasGemmStridedBatchedEx(
|
||||
cublas_handle, cublas_trans_a, cublas_trans_b, m,
|
||||
n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
|
||||
cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
|
||||
&sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
|
||||
a.size(0), compute_type, algo));
|
||||
}
|
||||
}
|
||||
|
20
backend-python/rwkv_pip/beta/cuda/wrapper.cpp
vendored
20
backend-python/rwkv_pip/beta/cuda/wrapper.cpp
vendored
@ -118,7 +118,9 @@ void mm8_one(int64_t N, int64_t M,
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
||||
#ifndef DISABLE_CUBLAS_GEMM
|
||||
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
|
||||
#endif
|
||||
|
||||
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||
Tensor v_mix, Tensor r_mix, Tensor kw,
|
||||
@ -134,6 +136,16 @@ Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
||||
Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
|
||||
Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out);
|
||||
|
||||
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
|
||||
Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix,
|
||||
Tensor r_mix, Tensor kw,
|
||||
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx,
|
||||
Tensor rw,
|
||||
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
|
||||
/* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v,
|
||||
/* imm */ Tensor r, /* imm */ Tensor s1,
|
||||
/* out */ Tensor x_plus_out, /* out */ Tensor s2);
|
||||
|
||||
Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
||||
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
|
||||
/* imm */ Tensor buf,
|
||||
@ -148,8 +160,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("wkv_forward", &wkv_forward, "wkv forward");
|
||||
m.def("mm8_seq", &mm8_seq, "mm8 seq");
|
||||
m.def("mm8_one", &mm8_one, "mm8 one");
|
||||
m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
|
||||
m.def("gemm_fp16_cublas", &gemm_fp16_cublas_tensor, "gemv fp16 cublas");
|
||||
m.def("att_one", &att_one, "att one");
|
||||
m.def("att_one_v5", &att_one_v5, "att one v5");
|
||||
m.def("att_seq", &att_seq, "att seq");
|
||||
m.def("ffn_seq", &ffn_seq, "ffn seq");
|
||||
m.def("ffn_one", &ffn_one, "ffn one");
|
||||
@ -159,8 +172,9 @@ TORCH_LIBRARY(rwkv, m) {
|
||||
m.def("wkv_forward", wkv_forward);
|
||||
m.def("mm8_seq", mm8_seq);
|
||||
m.def("mm8_one", mm8_one);
|
||||
m.def("gemm_fp16_cublas", gemm_fp16_cublas);
|
||||
m.def("gemm_fp16_cublas", gemm_fp16_cublas_tensor);
|
||||
m.def("att_one", att_one);
|
||||
m.def("att_one_v5", &att_one_v5);
|
||||
m.def("att_seq", att_seq);
|
||||
m.def("ffn_seq", ffn_seq);
|
||||
m.def("ffn_one", ffn_one);
|
||||
|
359
backend-python/rwkv_pip/beta/model.py
vendored
359
backend-python/rwkv_pip/beta/model.py
vendored
@ -3,7 +3,7 @@
|
||||
########################################################################################################
|
||||
|
||||
from typing import Optional
|
||||
import types, gc, os, time, re
|
||||
import types, gc, os, time, re, platform
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
@ -91,6 +91,7 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
|
||||
f"{current_path}/cuda/att_one.cu",
|
||||
f"{current_path}/cuda/att_seq.cu",
|
||||
f"{current_path}/cuda/ffn.cu",
|
||||
f"{current_path}/cuda/att_one_v5.cu",
|
||||
],
|
||||
verbose=True,
|
||||
extra_cuda_cflags=[
|
||||
@ -149,26 +150,40 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
|
||||
torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y)
|
||||
return y.to(dtype=x.dtype)
|
||||
|
||||
else:
|
||||
os.environ["RWKV_CUDA_ON"] = "0"
|
||||
|
||||
if os.environ.get("RWKV_CUDA_ON") == "1":
|
||||
|
||||
@MyStatic
|
||||
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
|
||||
if output_dtype is None:
|
||||
output_dtype = a.dtype
|
||||
if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda":
|
||||
assert len(b.shape) == 2
|
||||
if len(a.shape) == 1:
|
||||
assert len(b.shape) == 2
|
||||
c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device)
|
||||
a = a.unsqueeze(0)
|
||||
else:
|
||||
assert len(a.shape) == len(b.shape)
|
||||
assert len(a.shape) == 2 or len(a.shape) == 3
|
||||
# torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit
|
||||
if len(a.shape) == 2:
|
||||
c = torch.empty(
|
||||
(a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device
|
||||
)
|
||||
else:
|
||||
c = torch.empty(
|
||||
(a.shape[0], a.shape[1], b.shape[-1]),
|
||||
dtype=output_dtype,
|
||||
device=a.device,
|
||||
)
|
||||
torch.ops.rwkv.gemm_fp16_cublas(a, b, c)
|
||||
return c
|
||||
else:
|
||||
return (a @ b).to(output_dtype)
|
||||
|
||||
else:
|
||||
os.environ["RWKV_CUDA_ON"] = "0"
|
||||
|
||||
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
|
||||
if output_dtype is None:
|
||||
@ -217,7 +232,7 @@ class RWKV(MyModule):
|
||||
) # load model to CPU first
|
||||
# it is supported to load a pure meta-tensor state dict (e.g. for quick testing)
|
||||
for k, v in self.w.items():
|
||||
if v.is_meta:
|
||||
if isinstance(v, torch.Tensor) and v.is_meta:
|
||||
# torch.zeros_like(v, device='cpu') doesn't produce an all-zero tensor
|
||||
# if v is a meta tensor
|
||||
self.w[k] = torch.zeros(v.shape, dtype=v.dtype, device="cpu")
|
||||
@ -247,9 +262,14 @@ class RWKV(MyModule):
|
||||
args.n_embd = w["emb.weight"].shape[1]
|
||||
args.n_layer = 0
|
||||
keys = list(w.keys())
|
||||
self.version = 4
|
||||
for x in keys:
|
||||
layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
|
||||
args.n_layer = max(args.n_layer, layer_id + 1)
|
||||
if "ln_x" in x:
|
||||
self.version = 5
|
||||
if self.version == 5 and "att.time_decay" in x:
|
||||
args.n_head = w[x].shape[0]
|
||||
|
||||
####################### Compute strategy
|
||||
|
||||
@ -352,6 +372,20 @@ class RWKV(MyModule):
|
||||
del w["blocks.0.ln0.bias"]
|
||||
|
||||
print_need_newline = False
|
||||
|
||||
REAL_TIME_FIRST = False
|
||||
for x in list(w.keys()):
|
||||
if ".time_faaaa" in x:
|
||||
REAL_TIME_FIRST = True
|
||||
if REAL_TIME_FIRST:
|
||||
w = {
|
||||
k.replace(".time_faaaa", ".time_first")
|
||||
if ".time_faaaa" in k
|
||||
else k: v
|
||||
for k, v in w.items()
|
||||
}
|
||||
self.w = w
|
||||
|
||||
keys = list(w.keys())
|
||||
for x in keys:
|
||||
w[x].requires_grad = False
|
||||
@ -382,8 +416,19 @@ class RWKV(MyModule):
|
||||
w[x] = w[x].t()
|
||||
|
||||
if ".time_decay" in x: # need fp32 for this
|
||||
if self.version == 4:
|
||||
w[x] = -torch.exp(w[x].float())
|
||||
elif self.version == 5:
|
||||
w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1, 1, 1)
|
||||
elif ".time_first" in x: # need fp32 for this
|
||||
if self.version == 4:
|
||||
w[x] = w[x].float()
|
||||
elif self.version == 5:
|
||||
if REAL_TIME_FIRST:
|
||||
w[x] = w[x].float().reshape(-1, 1, 1)
|
||||
else:
|
||||
w[x] = torch.exp(w[x].float()).reshape(-1, 1, 1)
|
||||
elif ".ln_x" in x: # need fp32 for group_norm
|
||||
w[x] = w[x].float()
|
||||
else:
|
||||
if (len(w[x].shape) == 2) and ("emb" not in x):
|
||||
@ -931,6 +976,147 @@ class RWKV(MyModule):
|
||||
|
||||
########################################################################################################
|
||||
|
||||
@MyFunction
|
||||
def att_one_v5(
|
||||
self,
|
||||
x,
|
||||
sx,
|
||||
s,
|
||||
ln_w,
|
||||
ln_b,
|
||||
lx_w,
|
||||
lx_b,
|
||||
k_mix,
|
||||
v_mix,
|
||||
r_mix,
|
||||
t_decay,
|
||||
t_first,
|
||||
kw,
|
||||
vw,
|
||||
rw,
|
||||
ow,
|
||||
kmx,
|
||||
krx,
|
||||
kmy,
|
||||
kry,
|
||||
vmx,
|
||||
vrx,
|
||||
vmy,
|
||||
vry,
|
||||
rmx,
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
ory,
|
||||
):
|
||||
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
H = t_decay.shape[0]
|
||||
S = x.shape[-1] // H
|
||||
|
||||
r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S)
|
||||
k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S)
|
||||
|
||||
a = gemm(k, v)
|
||||
out = r @ (t_first * a + s)
|
||||
s = a + t_decay * s
|
||||
|
||||
out = out.flatten()
|
||||
out = F.group_norm(
|
||||
out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
|
||||
).squeeze(0)
|
||||
out = out.to(dtype=x.dtype)
|
||||
out = gemm(out, ow)
|
||||
|
||||
return x + out, xx, s
|
||||
|
||||
@MyFunction
|
||||
def att_seq_v5(
|
||||
self,
|
||||
x,
|
||||
sx,
|
||||
s,
|
||||
ln_w,
|
||||
ln_b,
|
||||
lx_w,
|
||||
lx_b,
|
||||
k_mix,
|
||||
v_mix,
|
||||
r_mix,
|
||||
t_decay,
|
||||
t_first,
|
||||
kw,
|
||||
vw,
|
||||
rw,
|
||||
ow,
|
||||
kmx,
|
||||
krx,
|
||||
kmy,
|
||||
kry,
|
||||
vmx,
|
||||
vrx,
|
||||
vmy,
|
||||
vry,
|
||||
rmx,
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
ory,
|
||||
):
|
||||
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
|
||||
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
H = t_decay.shape[0]
|
||||
S = x.shape[-1] // H
|
||||
T = x.shape[0]
|
||||
|
||||
w = t_decay.reshape(-1, 1)
|
||||
u = t_first.reshape(-1, 1)
|
||||
ws = w.pow(T).reshape(H, 1, 1)
|
||||
ind = torch.arange(T - 1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1)
|
||||
w = w.repeat(1, T).pow(ind)
|
||||
wk = w.reshape(H, 1, T)
|
||||
wb = wk.transpose(-2, -1).flip(1)
|
||||
w = torch.cat([w[:, 1:], u], dim=1)
|
||||
w = F.pad(w, (0, T))
|
||||
w = torch.tile(w, [T])
|
||||
w = w[:, :-T].reshape(-1, T, 2 * T - 1)
|
||||
w = w[:, :, T - 1 :].reshape(H, T, T)
|
||||
|
||||
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
|
||||
k = (
|
||||
gemm(kx, kw, output_dtype=torch.float32)
|
||||
.view(T, H, S)
|
||||
.transpose(0, 1)
|
||||
.transpose(-2, -1)
|
||||
)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
|
||||
|
||||
out = ((r @ k) * w) @ v + (r @ s) * wb
|
||||
s = ws * s + (k * wk) @ v
|
||||
|
||||
out = out.transpose(0, 1).contiguous().reshape(T, H * S)
|
||||
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
|
||||
out = out.to(dtype=x.dtype)
|
||||
out = gemm(out, ow)
|
||||
|
||||
return x + out, xx[-1, :], s
|
||||
|
||||
########################################################################################################
|
||||
|
||||
if os.environ["RWKV_CUDA_ON"] == "1":
|
||||
|
||||
@MyFunction
|
||||
@ -1140,7 +1326,7 @@ class RWKV(MyModule):
|
||||
xx = torch.ops.rwkv.ffn_seq(
|
||||
x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, buf, x_plus_out
|
||||
)
|
||||
return x_plus_out, xx[-1:]
|
||||
return x_plus_out, xx[-1, :]
|
||||
|
||||
@MyFunction
|
||||
def cuda_att_one_fp16(
|
||||
@ -1220,6 +1406,86 @@ class RWKV(MyModule):
|
||||
)
|
||||
return x_plus_out_t, xx, t1_t, t2_t, p_t
|
||||
|
||||
@MyFunction
|
||||
def cuda_att_one_v5_fp16(
|
||||
self,
|
||||
x,
|
||||
sx,
|
||||
s,
|
||||
ln_w,
|
||||
ln_b,
|
||||
lx_w,
|
||||
lx_b,
|
||||
k_mix,
|
||||
v_mix,
|
||||
r_mix,
|
||||
t_decay,
|
||||
t_first,
|
||||
kw,
|
||||
vw,
|
||||
rw,
|
||||
ow,
|
||||
kmx,
|
||||
krx,
|
||||
kmy,
|
||||
kry,
|
||||
vmx,
|
||||
vrx,
|
||||
vmy,
|
||||
vry,
|
||||
rmx,
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
ory,
|
||||
):
|
||||
kx = torch.empty_like(x)
|
||||
vx = torch.empty_like(x)
|
||||
rx = torch.empty_like(x)
|
||||
|
||||
H = t_decay.shape[0]
|
||||
S = x.shape[-1] // H
|
||||
|
||||
r = torch.empty((H * S,), dtype=torch.float32, device=x.device)
|
||||
k = torch.empty((H * S,), dtype=torch.float32, device=x.device)
|
||||
v = torch.empty((H * S,), dtype=torch.float32, device=x.device)
|
||||
s1 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
|
||||
s2 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
|
||||
x_plus_out = torch.empty_like(x)
|
||||
|
||||
xx = torch.ops.rwkv.att_one_v5(
|
||||
x,
|
||||
sx,
|
||||
s,
|
||||
ln_w,
|
||||
ln_b,
|
||||
lx_w,
|
||||
lx_b,
|
||||
k_mix,
|
||||
v_mix,
|
||||
r_mix,
|
||||
kw,
|
||||
kx,
|
||||
vw,
|
||||
vx,
|
||||
rw,
|
||||
rx,
|
||||
ow,
|
||||
t_first,
|
||||
k,
|
||||
t_decay,
|
||||
v,
|
||||
r,
|
||||
s1,
|
||||
x_plus_out,
|
||||
s2,
|
||||
)
|
||||
|
||||
return x_plus_out, xx, s2
|
||||
|
||||
@MyFunction
|
||||
def cuda_ffn_one_fp16(
|
||||
self,
|
||||
@ -1265,6 +1531,7 @@ class RWKV(MyModule):
|
||||
args = self.args
|
||||
|
||||
if state == None:
|
||||
if self.version == 4:
|
||||
state = [None] * args.n_layer * 5
|
||||
for i in range(
|
||||
args.n_layer
|
||||
@ -1276,10 +1543,16 @@ class RWKV(MyModule):
|
||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||
).contiguous()
|
||||
state[i * 5 + 1] = torch.zeros(
|
||||
args.n_embd, dtype=torch.float, requires_grad=False, device=dev
|
||||
args.n_embd,
|
||||
dtype=torch.float,
|
||||
requires_grad=False,
|
||||
device=dev,
|
||||
).contiguous()
|
||||
state[i * 5 + 2] = torch.zeros(
|
||||
args.n_embd, dtype=torch.float, requires_grad=False, device=dev
|
||||
args.n_embd,
|
||||
dtype=torch.float,
|
||||
requires_grad=False,
|
||||
device=dev,
|
||||
).contiguous()
|
||||
state[i * 5 + 3] = (
|
||||
torch.zeros(
|
||||
@ -1293,6 +1566,28 @@ class RWKV(MyModule):
|
||||
state[i * 5 + 4] = torch.zeros(
|
||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||
).contiguous()
|
||||
elif self.version == 5:
|
||||
state = [None] * args.n_layer * 3
|
||||
for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx
|
||||
dd = self.strategy[i]
|
||||
dev = dd.device
|
||||
atype = dd.atype
|
||||
state[i * 3 + 0] = torch.zeros(
|
||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||
).contiguous()
|
||||
state[i * 3 + 1] = torch.zeros(
|
||||
(
|
||||
args.n_head,
|
||||
args.n_embd // args.n_head,
|
||||
args.n_embd // args.n_head,
|
||||
),
|
||||
dtype=torch.float,
|
||||
requires_grad=False,
|
||||
device=dev,
|
||||
).contiguous()
|
||||
state[i * 3 + 2] = torch.zeros(
|
||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||
).contiguous()
|
||||
|
||||
seq_mode = len(tokens) > 1
|
||||
|
||||
@ -1317,9 +1612,13 @@ class RWKV(MyModule):
|
||||
ATT = self.cuda_att_seq_i8
|
||||
else:
|
||||
ATT = self.cuda_att_seq_naive
|
||||
if self.version == 5:
|
||||
ATT = self.att_seq_v5
|
||||
else:
|
||||
ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
|
||||
FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8
|
||||
if self.version == 5:
|
||||
ATT = self.att_one_v5
|
||||
if (
|
||||
"cuda" in str(dev)
|
||||
and os.environ["RWKV_CUDA_ON"] == "1"
|
||||
@ -1327,6 +1626,8 @@ class RWKV(MyModule):
|
||||
):
|
||||
ATT = self.cuda_att_one_fp16
|
||||
FFN = self.cuda_ffn_one_fp16
|
||||
if self.version == 5:
|
||||
ATT = self.cuda_att_one_v5_fp16
|
||||
|
||||
x = x.to(dtype=atype, device=dev)
|
||||
|
||||
@ -1355,6 +1656,7 @@ class RWKV(MyModule):
|
||||
orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x
|
||||
omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x
|
||||
ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x
|
||||
if self.version == 4:
|
||||
(
|
||||
x,
|
||||
state[i * 5 + 0],
|
||||
@ -1395,6 +1697,41 @@ class RWKV(MyModule):
|
||||
omy,
|
||||
ory,
|
||||
)
|
||||
elif self.version == 5:
|
||||
x, state[i * 3 + 0], state[i * 3 + 1] = ATT(
|
||||
x,
|
||||
state[i * 3 + 0],
|
||||
state[i * 3 + 1],
|
||||
w[f"{bbb}ln1.weight"],
|
||||
w[f"{bbb}ln1.bias"],
|
||||
w[f"{att}ln_x.weight"],
|
||||
w[f"{att}ln_x.bias"],
|
||||
w[f"{att}time_mix_k"],
|
||||
w[f"{att}time_mix_v"],
|
||||
w[f"{att}time_mix_r"],
|
||||
w[f"{att}time_decay"],
|
||||
w[f"{att}time_first"],
|
||||
kw,
|
||||
vw,
|
||||
rw,
|
||||
ow,
|
||||
kmx,
|
||||
krx,
|
||||
kmy,
|
||||
kry,
|
||||
vmx,
|
||||
vrx,
|
||||
vmy,
|
||||
vry,
|
||||
rmx,
|
||||
rrx,
|
||||
rmy,
|
||||
rry,
|
||||
omx,
|
||||
orx,
|
||||
omy,
|
||||
ory,
|
||||
)
|
||||
if dd.stream:
|
||||
del kw, vw, rw, ow
|
||||
|
||||
@ -1417,9 +1754,13 @@ class RWKV(MyModule):
|
||||
rrx = w[f"{ffn}receptance.weight_rx"] if wtype == torch.uint8 else x
|
||||
rmy = w[f"{ffn}receptance.weight_my"] if wtype == torch.uint8 else x
|
||||
rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x
|
||||
x, state[i * 5 + 4] = FFN(
|
||||
if self.version == 4:
|
||||
offset = i * 5 + 4
|
||||
elif self.version == 5:
|
||||
offset = i * 3 + 2
|
||||
x, state[offset] = FFN(
|
||||
x,
|
||||
state[i * 5 + 4],
|
||||
state[offset],
|
||||
w[f"{bbb}ln2.weight"],
|
||||
w[f"{bbb}ln2.bias"],
|
||||
w[f"{ffn}time_mix_k"],
|
||||
|
Loading…
Reference in New Issue
Block a user