upgrade cuda-beta
This commit is contained in:
parent
c4042bbfd8
commit
df969fcfc6
10
backend-python/rwkv_pip/beta/cuda/att_one.cu
vendored
10
backend-python/rwkv_pip/beta/cuda/att_one.cu
vendored
@ -88,7 +88,7 @@ struct Mix {
|
|||||||
|
|
||||||
using torch::Tensor;
|
using torch::Tensor;
|
||||||
|
|
||||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
|
||||||
|
|
||||||
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||||
Tensor v_mix, Tensor r_mix, Tensor kw,
|
Tensor v_mix, Tensor r_mix, Tensor kw,
|
||||||
@ -105,9 +105,9 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
|||||||
data_ptr<half>(vx), data_ptr<half>(rx)},
|
data_ptr<half>(vx), data_ptr<half>(rx)},
|
||||||
x.numel());
|
x.numel());
|
||||||
|
|
||||||
gemm_fp16_cublas(kx, kw, k);
|
gemm_fp16_cublas_tensor(kx, kw, k);
|
||||||
gemm_fp16_cublas(vx, vw, v);
|
gemm_fp16_cublas_tensor(vx, vw, v);
|
||||||
gemm_fp16_cublas(rx, rw, r);
|
gemm_fp16_cublas_tensor(rx, rw, r);
|
||||||
at::sigmoid_(r);
|
at::sigmoid_(r);
|
||||||
|
|
||||||
element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
|
element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
|
||||||
@ -118,7 +118,7 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
|||||||
data_ptr<half>(r)},
|
data_ptr<half>(r)},
|
||||||
x.numel());
|
x.numel());
|
||||||
|
|
||||||
gemm_fp16_cublas(r, ow, x_plus_out);
|
gemm_fp16_cublas_tensor(r, ow, x_plus_out);
|
||||||
x_plus_out += x;
|
x_plus_out += x;
|
||||||
return xx;
|
return xx;
|
||||||
}
|
}
|
||||||
|
109
backend-python/rwkv_pip/beta/cuda/att_one_v5.cu
vendored
Normal file
109
backend-python/rwkv_pip/beta/cuda/att_one_v5.cu
vendored
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
#include "ATen/ATen.h"
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include "element_wise.h"
|
||||||
|
#include "util.h"
|
||||||
|
|
||||||
|
// Equivalent Python code:
|
||||||
|
// s1 = t_first * a + s
|
||||||
|
// s2 = a + t_decay * s
|
||||||
|
struct Fused1 {
|
||||||
|
const float *t_first;
|
||||||
|
const float *t_decay;
|
||||||
|
const float *a;
|
||||||
|
const float *s;
|
||||||
|
const int32_t inner_size;
|
||||||
|
/* out */ float *s1;
|
||||||
|
/* out */ float *s2;
|
||||||
|
|
||||||
|
__device__ void operator()(int i) const {
|
||||||
|
const int j = i / inner_size;
|
||||||
|
s1[i] = t_first[j] * a[i] + s[i];
|
||||||
|
s2[i] = a[i] + t_decay[j] * s[i];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
Equivalent Python code:
|
||||||
|
kx = xx * k_mix + sx * (1 - k_mix)
|
||||||
|
vx = xx * v_mix + sx * (1 - v_mix)
|
||||||
|
rx = xx * r_mix + sx * (1 - r_mix)
|
||||||
|
*/
|
||||||
|
|
||||||
|
struct Mix {
|
||||||
|
const half *xx;
|
||||||
|
const half *sx;
|
||||||
|
const half *k_mix;
|
||||||
|
const half *v_mix;
|
||||||
|
const half *r_mix;
|
||||||
|
/* out */ half *kx;
|
||||||
|
/* out */ half *vx;
|
||||||
|
/* out */ half *rx;
|
||||||
|
|
||||||
|
__device__ void operator()(int i) const {
|
||||||
|
half xx_ = xx[i];
|
||||||
|
half sx_ = sx[i];
|
||||||
|
half k_mix_ = k_mix[i];
|
||||||
|
half v_mix_ = v_mix[i];
|
||||||
|
half r_mix_ = r_mix[i];
|
||||||
|
kx[i] = __hadd(__hmul(xx_, k_mix_),
|
||||||
|
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
|
||||||
|
vx[i] = __hadd(__hmul(xx_, v_mix_),
|
||||||
|
__hmul(sx_, __hsub(__float2half(1), v_mix_)));
|
||||||
|
rx[i] = __hadd(__hmul(xx_, r_mix_),
|
||||||
|
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using torch::Tensor;
|
||||||
|
|
||||||
|
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
|
||||||
|
|
||||||
|
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
|
||||||
|
Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix,
|
||||||
|
Tensor r_mix, Tensor kw,
|
||||||
|
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx,
|
||||||
|
Tensor rw,
|
||||||
|
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
|
||||||
|
/* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v,
|
||||||
|
/* imm */ Tensor r, /* imm */ Tensor s1,
|
||||||
|
/* out */ Tensor x_plus_out, /* out */ Tensor s2) {
|
||||||
|
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
|
||||||
|
element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
|
||||||
|
data_ptr<half>(k_mix), data_ptr<half>(v_mix),
|
||||||
|
data_ptr<half>(r_mix), data_ptr<half>(kx),
|
||||||
|
data_ptr<half>(vx), data_ptr<half>(rx)},
|
||||||
|
x.numel());
|
||||||
|
|
||||||
|
int H = t_decay.size(0);
|
||||||
|
int S = x.size(-1) / H;
|
||||||
|
gemm_fp16_cublas_tensor(rx, rw, r);
|
||||||
|
r = at::reshape(r, {H, 1, S});
|
||||||
|
gemm_fp16_cublas_tensor(kx, kw, k);
|
||||||
|
k = at::reshape(k, {H, S, 1});
|
||||||
|
gemm_fp16_cublas_tensor(vx, vw, v);
|
||||||
|
v = at::reshape(v, {H, 1, S});
|
||||||
|
|
||||||
|
{
|
||||||
|
Tensor a = at::matmul(k, v);
|
||||||
|
|
||||||
|
// s1 = t_first * a + s
|
||||||
|
// s2 = a + t_decay * s
|
||||||
|
element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay),
|
||||||
|
data_ptr<float>(a), data_ptr<float>(s),
|
||||||
|
static_cast<int32_t>(a.size(1) * a.size(2)),
|
||||||
|
data_ptr<float>(s1), data_ptr<float>(s2)},
|
||||||
|
a.numel());
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor out = at::matmul(r, s1);
|
||||||
|
out = at::flatten(out);
|
||||||
|
out = at::squeeze(at::group_norm(at::unsqueeze(out, 0), H, lx_w, lx_b), 0);
|
||||||
|
out = at::_cast_Half(out);
|
||||||
|
|
||||||
|
gemm_fp16_cublas_tensor(out, ow, x_plus_out);
|
||||||
|
x_plus_out += x;
|
||||||
|
return xx;
|
||||||
|
}
|
1
backend-python/rwkv_pip/beta/cuda/att_seq.cu
vendored
1
backend-python/rwkv_pip/beta/cuda/att_seq.cu
vendored
@ -8,7 +8,6 @@
|
|||||||
|
|
||||||
using torch::Tensor;
|
using torch::Tensor;
|
||||||
|
|
||||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
|
||||||
void gemm_fp16_cublas(const void *a, const void *b, void *c, int m,
|
void gemm_fp16_cublas(const void *a, const void *b, void *c, int m,
|
||||||
int n, int k, bool output_fp32);
|
int n, int k, bool output_fp32);
|
||||||
|
|
||||||
|
@ -70,11 +70,59 @@ void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m,
|
|||||||
cuda_c_data_type, cublas_ldc, compute_type, algo));
|
cuda_c_data_type, cublas_ldc, compute_type, algo));
|
||||||
}
|
}
|
||||||
|
|
||||||
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
|
/*
|
||||||
// comptiable with rwkv one mode, 1-D tensor * 2-D tensor
|
NOTE: blas gemm is column-major by default, but we need row-major output.
|
||||||
const int m = a.dense_dim() == 1 ? 1 : a.size(0);
|
The data of row-major, transposed matrix is exactly the same as the
|
||||||
const int n = b.size(1);
|
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
|
||||||
const int k = b.size(0);
|
*/
|
||||||
gemm_fp16_cublas(a.data_ptr(), b.data_ptr(), c.data_ptr(), m, n, k,
|
void gemm_fp16_cublas_tensor(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
|
||||||
c.dtype() == torch::kFloat32);
|
if (a.sizes().size() == 1) {
|
||||||
|
assert(b.sizes().size() == 2);
|
||||||
|
a = at::unsqueeze(a, 0);
|
||||||
|
}
|
||||||
|
const auto cuda_data_type = CUDA_R_16F;
|
||||||
|
const auto cuda_c_data_type =
|
||||||
|
c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
|
||||||
|
const auto compute_type = CUDA_R_32F;
|
||||||
|
const float sp_alpha = 1.f;
|
||||||
|
// swap a and b, and use CUBLAS_OP_N. see the notes above
|
||||||
|
std::swap(a, b);
|
||||||
|
const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
|
||||||
|
const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
|
||||||
|
// m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
|
||||||
|
// negative axis is used because of the existence of batch matmul.
|
||||||
|
const int m = a.size(-1);
|
||||||
|
const int k = a.size(-2);
|
||||||
|
const int n = b.size(-2);
|
||||||
|
const int cublas_lda = m;
|
||||||
|
const int cublas_ldb = k;
|
||||||
|
const int cublas_ldc = m;
|
||||||
|
cublasHandle_t cublas_handle = get_cublas_handle();
|
||||||
|
|
||||||
|
#if CUDA_VERSION >= 11000
|
||||||
|
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||||
|
#else
|
||||||
|
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
|
||||||
|
#endif
|
||||||
|
const float sp_beta = 0.f;
|
||||||
|
if (a.sizes().size() == 2 && b.sizes().size() == 2) {
|
||||||
|
CUBLAS_CHECK(cublasGemmEx(
|
||||||
|
cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
|
||||||
|
a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
|
||||||
|
cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
|
||||||
|
compute_type, algo));
|
||||||
|
} else {
|
||||||
|
// batch matmul
|
||||||
|
assert(a.sizes().size() == 3 && b.sizes().size() == 3);
|
||||||
|
|
||||||
|
const long long int cublas_stride_a = m * k;
|
||||||
|
const long long int cublas_stride_b = k * n;
|
||||||
|
const long long int cublas_stride_c = m * n;
|
||||||
|
CUBLAS_CHECK(cublasGemmStridedBatchedEx(
|
||||||
|
cublas_handle, cublas_trans_a, cublas_trans_b, m,
|
||||||
|
n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
|
||||||
|
cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
|
||||||
|
&sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
|
||||||
|
a.size(0), compute_type, algo));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
20
backend-python/rwkv_pip/beta/cuda/wrapper.cpp
vendored
20
backend-python/rwkv_pip/beta/cuda/wrapper.cpp
vendored
@ -118,7 +118,9 @@ void mm8_one(int64_t N, int64_t M,
|
|||||||
|
|
||||||
using torch::Tensor;
|
using torch::Tensor;
|
||||||
|
|
||||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
#ifndef DISABLE_CUBLAS_GEMM
|
||||||
|
void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
|
||||||
|
#endif
|
||||||
|
|
||||||
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||||
Tensor v_mix, Tensor r_mix, Tensor kw,
|
Tensor v_mix, Tensor r_mix, Tensor kw,
|
||||||
@ -134,6 +136,16 @@ Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
|||||||
Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
|
Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
|
||||||
Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out);
|
Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out);
|
||||||
|
|
||||||
|
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
|
||||||
|
Tensor lx_w, Tensor lx_b, Tensor k_mix, Tensor v_mix,
|
||||||
|
Tensor r_mix, Tensor kw,
|
||||||
|
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx,
|
||||||
|
Tensor rw,
|
||||||
|
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
|
||||||
|
/* imm */ Tensor k, Tensor t_decay, /* imm */ Tensor v,
|
||||||
|
/* imm */ Tensor r, /* imm */ Tensor s1,
|
||||||
|
/* out */ Tensor x_plus_out, /* out */ Tensor s2);
|
||||||
|
|
||||||
Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
||||||
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
|
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
|
||||||
/* imm */ Tensor buf,
|
/* imm */ Tensor buf,
|
||||||
@ -148,8 +160,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
m.def("wkv_forward", &wkv_forward, "wkv forward");
|
m.def("wkv_forward", &wkv_forward, "wkv forward");
|
||||||
m.def("mm8_seq", &mm8_seq, "mm8 seq");
|
m.def("mm8_seq", &mm8_seq, "mm8 seq");
|
||||||
m.def("mm8_one", &mm8_one, "mm8 one");
|
m.def("mm8_one", &mm8_one, "mm8 one");
|
||||||
m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
|
m.def("gemm_fp16_cublas", &gemm_fp16_cublas_tensor, "gemv fp16 cublas");
|
||||||
m.def("att_one", &att_one, "att one");
|
m.def("att_one", &att_one, "att one");
|
||||||
|
m.def("att_one_v5", &att_one_v5, "att one v5");
|
||||||
m.def("att_seq", &att_seq, "att seq");
|
m.def("att_seq", &att_seq, "att seq");
|
||||||
m.def("ffn_seq", &ffn_seq, "ffn seq");
|
m.def("ffn_seq", &ffn_seq, "ffn seq");
|
||||||
m.def("ffn_one", &ffn_one, "ffn one");
|
m.def("ffn_one", &ffn_one, "ffn one");
|
||||||
@ -159,8 +172,9 @@ TORCH_LIBRARY(rwkv, m) {
|
|||||||
m.def("wkv_forward", wkv_forward);
|
m.def("wkv_forward", wkv_forward);
|
||||||
m.def("mm8_seq", mm8_seq);
|
m.def("mm8_seq", mm8_seq);
|
||||||
m.def("mm8_one", mm8_one);
|
m.def("mm8_one", mm8_one);
|
||||||
m.def("gemm_fp16_cublas", gemm_fp16_cublas);
|
m.def("gemm_fp16_cublas", gemm_fp16_cublas_tensor);
|
||||||
m.def("att_one", att_one);
|
m.def("att_one", att_one);
|
||||||
|
m.def("att_one_v5", &att_one_v5);
|
||||||
m.def("att_seq", att_seq);
|
m.def("att_seq", att_seq);
|
||||||
m.def("ffn_seq", ffn_seq);
|
m.def("ffn_seq", ffn_seq);
|
||||||
m.def("ffn_one", ffn_one);
|
m.def("ffn_one", ffn_one);
|
||||||
|
359
backend-python/rwkv_pip/beta/model.py
vendored
359
backend-python/rwkv_pip/beta/model.py
vendored
@ -3,7 +3,7 @@
|
|||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import types, gc, os, time, re
|
import types, gc, os, time, re, platform
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
@ -91,6 +91,7 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
|
|||||||
f"{current_path}/cuda/att_one.cu",
|
f"{current_path}/cuda/att_one.cu",
|
||||||
f"{current_path}/cuda/att_seq.cu",
|
f"{current_path}/cuda/att_seq.cu",
|
||||||
f"{current_path}/cuda/ffn.cu",
|
f"{current_path}/cuda/ffn.cu",
|
||||||
|
f"{current_path}/cuda/att_one_v5.cu",
|
||||||
],
|
],
|
||||||
verbose=True,
|
verbose=True,
|
||||||
extra_cuda_cflags=[
|
extra_cuda_cflags=[
|
||||||
@ -149,26 +150,40 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
|
|||||||
torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y)
|
torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y)
|
||||||
return y.to(dtype=x.dtype)
|
return y.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
else:
|
||||||
|
os.environ["RWKV_CUDA_ON"] = "0"
|
||||||
|
|
||||||
|
if os.environ.get("RWKV_CUDA_ON") == "1":
|
||||||
|
|
||||||
@MyStatic
|
@MyStatic
|
||||||
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
|
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
|
||||||
if output_dtype is None:
|
if output_dtype is None:
|
||||||
output_dtype = a.dtype
|
output_dtype = a.dtype
|
||||||
if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda":
|
if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda":
|
||||||
assert len(b.shape) == 2
|
|
||||||
if len(a.shape) == 1:
|
if len(a.shape) == 1:
|
||||||
|
assert len(b.shape) == 2
|
||||||
c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device)
|
c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device)
|
||||||
a = a.unsqueeze(0)
|
a = a.unsqueeze(0)
|
||||||
else:
|
else:
|
||||||
|
assert len(a.shape) == len(b.shape)
|
||||||
|
assert len(a.shape) == 2 or len(a.shape) == 3
|
||||||
|
# torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit
|
||||||
|
if len(a.shape) == 2:
|
||||||
c = torch.empty(
|
c = torch.empty(
|
||||||
(a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device
|
(a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
c = torch.empty(
|
||||||
|
(a.shape[0], a.shape[1], b.shape[-1]),
|
||||||
|
dtype=output_dtype,
|
||||||
|
device=a.device,
|
||||||
|
)
|
||||||
torch.ops.rwkv.gemm_fp16_cublas(a, b, c)
|
torch.ops.rwkv.gemm_fp16_cublas(a, b, c)
|
||||||
return c
|
return c
|
||||||
else:
|
else:
|
||||||
return (a @ b).to(output_dtype)
|
return (a @ b).to(output_dtype)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
os.environ["RWKV_CUDA_ON"] = "0"
|
|
||||||
|
|
||||||
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
|
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
|
||||||
if output_dtype is None:
|
if output_dtype is None:
|
||||||
@ -217,7 +232,7 @@ class RWKV(MyModule):
|
|||||||
) # load model to CPU first
|
) # load model to CPU first
|
||||||
# it is supported to load a pure meta-tensor state dict (e.g. for quick testing)
|
# it is supported to load a pure meta-tensor state dict (e.g. for quick testing)
|
||||||
for k, v in self.w.items():
|
for k, v in self.w.items():
|
||||||
if v.is_meta:
|
if isinstance(v, torch.Tensor) and v.is_meta:
|
||||||
# torch.zeros_like(v, device='cpu') doesn't produce an all-zero tensor
|
# torch.zeros_like(v, device='cpu') doesn't produce an all-zero tensor
|
||||||
# if v is a meta tensor
|
# if v is a meta tensor
|
||||||
self.w[k] = torch.zeros(v.shape, dtype=v.dtype, device="cpu")
|
self.w[k] = torch.zeros(v.shape, dtype=v.dtype, device="cpu")
|
||||||
@ -247,9 +262,14 @@ class RWKV(MyModule):
|
|||||||
args.n_embd = w["emb.weight"].shape[1]
|
args.n_embd = w["emb.weight"].shape[1]
|
||||||
args.n_layer = 0
|
args.n_layer = 0
|
||||||
keys = list(w.keys())
|
keys = list(w.keys())
|
||||||
|
self.version = 4
|
||||||
for x in keys:
|
for x in keys:
|
||||||
layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
|
layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
|
||||||
args.n_layer = max(args.n_layer, layer_id + 1)
|
args.n_layer = max(args.n_layer, layer_id + 1)
|
||||||
|
if "ln_x" in x:
|
||||||
|
self.version = 5
|
||||||
|
if self.version == 5 and "att.time_decay" in x:
|
||||||
|
args.n_head = w[x].shape[0]
|
||||||
|
|
||||||
####################### Compute strategy
|
####################### Compute strategy
|
||||||
|
|
||||||
@ -352,6 +372,20 @@ class RWKV(MyModule):
|
|||||||
del w["blocks.0.ln0.bias"]
|
del w["blocks.0.ln0.bias"]
|
||||||
|
|
||||||
print_need_newline = False
|
print_need_newline = False
|
||||||
|
|
||||||
|
REAL_TIME_FIRST = False
|
||||||
|
for x in list(w.keys()):
|
||||||
|
if ".time_faaaa" in x:
|
||||||
|
REAL_TIME_FIRST = True
|
||||||
|
if REAL_TIME_FIRST:
|
||||||
|
w = {
|
||||||
|
k.replace(".time_faaaa", ".time_first")
|
||||||
|
if ".time_faaaa" in k
|
||||||
|
else k: v
|
||||||
|
for k, v in w.items()
|
||||||
|
}
|
||||||
|
self.w = w
|
||||||
|
|
||||||
keys = list(w.keys())
|
keys = list(w.keys())
|
||||||
for x in keys:
|
for x in keys:
|
||||||
w[x].requires_grad = False
|
w[x].requires_grad = False
|
||||||
@ -382,8 +416,19 @@ class RWKV(MyModule):
|
|||||||
w[x] = w[x].t()
|
w[x] = w[x].t()
|
||||||
|
|
||||||
if ".time_decay" in x: # need fp32 for this
|
if ".time_decay" in x: # need fp32 for this
|
||||||
|
if self.version == 4:
|
||||||
w[x] = -torch.exp(w[x].float())
|
w[x] = -torch.exp(w[x].float())
|
||||||
|
elif self.version == 5:
|
||||||
|
w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1, 1, 1)
|
||||||
elif ".time_first" in x: # need fp32 for this
|
elif ".time_first" in x: # need fp32 for this
|
||||||
|
if self.version == 4:
|
||||||
|
w[x] = w[x].float()
|
||||||
|
elif self.version == 5:
|
||||||
|
if REAL_TIME_FIRST:
|
||||||
|
w[x] = w[x].float().reshape(-1, 1, 1)
|
||||||
|
else:
|
||||||
|
w[x] = torch.exp(w[x].float()).reshape(-1, 1, 1)
|
||||||
|
elif ".ln_x" in x: # need fp32 for group_norm
|
||||||
w[x] = w[x].float()
|
w[x] = w[x].float()
|
||||||
else:
|
else:
|
||||||
if (len(w[x].shape) == 2) and ("emb" not in x):
|
if (len(w[x].shape) == 2) and ("emb" not in x):
|
||||||
@ -931,6 +976,147 @@ class RWKV(MyModule):
|
|||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
|
||||||
|
@MyFunction
|
||||||
|
def att_one_v5(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
sx,
|
||||||
|
s,
|
||||||
|
ln_w,
|
||||||
|
ln_b,
|
||||||
|
lx_w,
|
||||||
|
lx_b,
|
||||||
|
k_mix,
|
||||||
|
v_mix,
|
||||||
|
r_mix,
|
||||||
|
t_decay,
|
||||||
|
t_first,
|
||||||
|
kw,
|
||||||
|
vw,
|
||||||
|
rw,
|
||||||
|
ow,
|
||||||
|
kmx,
|
||||||
|
krx,
|
||||||
|
kmy,
|
||||||
|
kry,
|
||||||
|
vmx,
|
||||||
|
vrx,
|
||||||
|
vmy,
|
||||||
|
vry,
|
||||||
|
rmx,
|
||||||
|
rrx,
|
||||||
|
rmy,
|
||||||
|
rry,
|
||||||
|
omx,
|
||||||
|
orx,
|
||||||
|
omy,
|
||||||
|
ory,
|
||||||
|
):
|
||||||
|
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
|
||||||
|
kx = xx * k_mix + sx * (1 - k_mix)
|
||||||
|
vx = xx * v_mix + sx * (1 - v_mix)
|
||||||
|
rx = xx * r_mix + sx * (1 - r_mix)
|
||||||
|
|
||||||
|
H = t_decay.shape[0]
|
||||||
|
S = x.shape[-1] // H
|
||||||
|
|
||||||
|
r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S)
|
||||||
|
k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1)
|
||||||
|
v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S)
|
||||||
|
|
||||||
|
a = gemm(k, v)
|
||||||
|
out = r @ (t_first * a + s)
|
||||||
|
s = a + t_decay * s
|
||||||
|
|
||||||
|
out = out.flatten()
|
||||||
|
out = F.group_norm(
|
||||||
|
out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
|
||||||
|
).squeeze(0)
|
||||||
|
out = out.to(dtype=x.dtype)
|
||||||
|
out = gemm(out, ow)
|
||||||
|
|
||||||
|
return x + out, xx, s
|
||||||
|
|
||||||
|
@MyFunction
|
||||||
|
def att_seq_v5(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
sx,
|
||||||
|
s,
|
||||||
|
ln_w,
|
||||||
|
ln_b,
|
||||||
|
lx_w,
|
||||||
|
lx_b,
|
||||||
|
k_mix,
|
||||||
|
v_mix,
|
||||||
|
r_mix,
|
||||||
|
t_decay,
|
||||||
|
t_first,
|
||||||
|
kw,
|
||||||
|
vw,
|
||||||
|
rw,
|
||||||
|
ow,
|
||||||
|
kmx,
|
||||||
|
krx,
|
||||||
|
kmy,
|
||||||
|
kry,
|
||||||
|
vmx,
|
||||||
|
vrx,
|
||||||
|
vmy,
|
||||||
|
vry,
|
||||||
|
rmx,
|
||||||
|
rrx,
|
||||||
|
rmy,
|
||||||
|
rry,
|
||||||
|
omx,
|
||||||
|
orx,
|
||||||
|
omy,
|
||||||
|
ory,
|
||||||
|
):
|
||||||
|
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
|
||||||
|
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
|
||||||
|
kx = xx * k_mix + sx * (1 - k_mix)
|
||||||
|
vx = xx * v_mix + sx * (1 - v_mix)
|
||||||
|
rx = xx * r_mix + sx * (1 - r_mix)
|
||||||
|
|
||||||
|
H = t_decay.shape[0]
|
||||||
|
S = x.shape[-1] // H
|
||||||
|
T = x.shape[0]
|
||||||
|
|
||||||
|
w = t_decay.reshape(-1, 1)
|
||||||
|
u = t_first.reshape(-1, 1)
|
||||||
|
ws = w.pow(T).reshape(H, 1, 1)
|
||||||
|
ind = torch.arange(T - 1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1)
|
||||||
|
w = w.repeat(1, T).pow(ind)
|
||||||
|
wk = w.reshape(H, 1, T)
|
||||||
|
wb = wk.transpose(-2, -1).flip(1)
|
||||||
|
w = torch.cat([w[:, 1:], u], dim=1)
|
||||||
|
w = F.pad(w, (0, T))
|
||||||
|
w = torch.tile(w, [T])
|
||||||
|
w = w[:, :-T].reshape(-1, T, 2 * T - 1)
|
||||||
|
w = w[:, :, T - 1 :].reshape(H, T, T)
|
||||||
|
|
||||||
|
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
|
||||||
|
k = (
|
||||||
|
gemm(kx, kw, output_dtype=torch.float32)
|
||||||
|
.view(T, H, S)
|
||||||
|
.transpose(0, 1)
|
||||||
|
.transpose(-2, -1)
|
||||||
|
)
|
||||||
|
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
|
||||||
|
|
||||||
|
out = ((r @ k) * w) @ v + (r @ s) * wb
|
||||||
|
s = ws * s + (k * wk) @ v
|
||||||
|
|
||||||
|
out = out.transpose(0, 1).contiguous().reshape(T, H * S)
|
||||||
|
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
|
||||||
|
out = out.to(dtype=x.dtype)
|
||||||
|
out = gemm(out, ow)
|
||||||
|
|
||||||
|
return x + out, xx[-1, :], s
|
||||||
|
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
if os.environ["RWKV_CUDA_ON"] == "1":
|
if os.environ["RWKV_CUDA_ON"] == "1":
|
||||||
|
|
||||||
@MyFunction
|
@MyFunction
|
||||||
@ -1140,7 +1326,7 @@ class RWKV(MyModule):
|
|||||||
xx = torch.ops.rwkv.ffn_seq(
|
xx = torch.ops.rwkv.ffn_seq(
|
||||||
x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, buf, x_plus_out
|
x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, buf, x_plus_out
|
||||||
)
|
)
|
||||||
return x_plus_out, xx[-1:]
|
return x_plus_out, xx[-1, :]
|
||||||
|
|
||||||
@MyFunction
|
@MyFunction
|
||||||
def cuda_att_one_fp16(
|
def cuda_att_one_fp16(
|
||||||
@ -1220,6 +1406,86 @@ class RWKV(MyModule):
|
|||||||
)
|
)
|
||||||
return x_plus_out_t, xx, t1_t, t2_t, p_t
|
return x_plus_out_t, xx, t1_t, t2_t, p_t
|
||||||
|
|
||||||
|
@MyFunction
|
||||||
|
def cuda_att_one_v5_fp16(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
sx,
|
||||||
|
s,
|
||||||
|
ln_w,
|
||||||
|
ln_b,
|
||||||
|
lx_w,
|
||||||
|
lx_b,
|
||||||
|
k_mix,
|
||||||
|
v_mix,
|
||||||
|
r_mix,
|
||||||
|
t_decay,
|
||||||
|
t_first,
|
||||||
|
kw,
|
||||||
|
vw,
|
||||||
|
rw,
|
||||||
|
ow,
|
||||||
|
kmx,
|
||||||
|
krx,
|
||||||
|
kmy,
|
||||||
|
kry,
|
||||||
|
vmx,
|
||||||
|
vrx,
|
||||||
|
vmy,
|
||||||
|
vry,
|
||||||
|
rmx,
|
||||||
|
rrx,
|
||||||
|
rmy,
|
||||||
|
rry,
|
||||||
|
omx,
|
||||||
|
orx,
|
||||||
|
omy,
|
||||||
|
ory,
|
||||||
|
):
|
||||||
|
kx = torch.empty_like(x)
|
||||||
|
vx = torch.empty_like(x)
|
||||||
|
rx = torch.empty_like(x)
|
||||||
|
|
||||||
|
H = t_decay.shape[0]
|
||||||
|
S = x.shape[-1] // H
|
||||||
|
|
||||||
|
r = torch.empty((H * S,), dtype=torch.float32, device=x.device)
|
||||||
|
k = torch.empty((H * S,), dtype=torch.float32, device=x.device)
|
||||||
|
v = torch.empty((H * S,), dtype=torch.float32, device=x.device)
|
||||||
|
s1 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
|
||||||
|
s2 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
|
||||||
|
x_plus_out = torch.empty_like(x)
|
||||||
|
|
||||||
|
xx = torch.ops.rwkv.att_one_v5(
|
||||||
|
x,
|
||||||
|
sx,
|
||||||
|
s,
|
||||||
|
ln_w,
|
||||||
|
ln_b,
|
||||||
|
lx_w,
|
||||||
|
lx_b,
|
||||||
|
k_mix,
|
||||||
|
v_mix,
|
||||||
|
r_mix,
|
||||||
|
kw,
|
||||||
|
kx,
|
||||||
|
vw,
|
||||||
|
vx,
|
||||||
|
rw,
|
||||||
|
rx,
|
||||||
|
ow,
|
||||||
|
t_first,
|
||||||
|
k,
|
||||||
|
t_decay,
|
||||||
|
v,
|
||||||
|
r,
|
||||||
|
s1,
|
||||||
|
x_plus_out,
|
||||||
|
s2,
|
||||||
|
)
|
||||||
|
|
||||||
|
return x_plus_out, xx, s2
|
||||||
|
|
||||||
@MyFunction
|
@MyFunction
|
||||||
def cuda_ffn_one_fp16(
|
def cuda_ffn_one_fp16(
|
||||||
self,
|
self,
|
||||||
@ -1265,6 +1531,7 @@ class RWKV(MyModule):
|
|||||||
args = self.args
|
args = self.args
|
||||||
|
|
||||||
if state == None:
|
if state == None:
|
||||||
|
if self.version == 4:
|
||||||
state = [None] * args.n_layer * 5
|
state = [None] * args.n_layer * 5
|
||||||
for i in range(
|
for i in range(
|
||||||
args.n_layer
|
args.n_layer
|
||||||
@ -1276,10 +1543,16 @@ class RWKV(MyModule):
|
|||||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||||
).contiguous()
|
).contiguous()
|
||||||
state[i * 5 + 1] = torch.zeros(
|
state[i * 5 + 1] = torch.zeros(
|
||||||
args.n_embd, dtype=torch.float, requires_grad=False, device=dev
|
args.n_embd,
|
||||||
|
dtype=torch.float,
|
||||||
|
requires_grad=False,
|
||||||
|
device=dev,
|
||||||
).contiguous()
|
).contiguous()
|
||||||
state[i * 5 + 2] = torch.zeros(
|
state[i * 5 + 2] = torch.zeros(
|
||||||
args.n_embd, dtype=torch.float, requires_grad=False, device=dev
|
args.n_embd,
|
||||||
|
dtype=torch.float,
|
||||||
|
requires_grad=False,
|
||||||
|
device=dev,
|
||||||
).contiguous()
|
).contiguous()
|
||||||
state[i * 5 + 3] = (
|
state[i * 5 + 3] = (
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
@ -1293,6 +1566,28 @@ class RWKV(MyModule):
|
|||||||
state[i * 5 + 4] = torch.zeros(
|
state[i * 5 + 4] = torch.zeros(
|
||||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||||
).contiguous()
|
).contiguous()
|
||||||
|
elif self.version == 5:
|
||||||
|
state = [None] * args.n_layer * 3
|
||||||
|
for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx
|
||||||
|
dd = self.strategy[i]
|
||||||
|
dev = dd.device
|
||||||
|
atype = dd.atype
|
||||||
|
state[i * 3 + 0] = torch.zeros(
|
||||||
|
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||||
|
).contiguous()
|
||||||
|
state[i * 3 + 1] = torch.zeros(
|
||||||
|
(
|
||||||
|
args.n_head,
|
||||||
|
args.n_embd // args.n_head,
|
||||||
|
args.n_embd // args.n_head,
|
||||||
|
),
|
||||||
|
dtype=torch.float,
|
||||||
|
requires_grad=False,
|
||||||
|
device=dev,
|
||||||
|
).contiguous()
|
||||||
|
state[i * 3 + 2] = torch.zeros(
|
||||||
|
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||||
|
).contiguous()
|
||||||
|
|
||||||
seq_mode = len(tokens) > 1
|
seq_mode = len(tokens) > 1
|
||||||
|
|
||||||
@ -1317,9 +1612,13 @@ class RWKV(MyModule):
|
|||||||
ATT = self.cuda_att_seq_i8
|
ATT = self.cuda_att_seq_i8
|
||||||
else:
|
else:
|
||||||
ATT = self.cuda_att_seq_naive
|
ATT = self.cuda_att_seq_naive
|
||||||
|
if self.version == 5:
|
||||||
|
ATT = self.att_seq_v5
|
||||||
else:
|
else:
|
||||||
ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
|
ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
|
||||||
FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8
|
FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8
|
||||||
|
if self.version == 5:
|
||||||
|
ATT = self.att_one_v5
|
||||||
if (
|
if (
|
||||||
"cuda" in str(dev)
|
"cuda" in str(dev)
|
||||||
and os.environ["RWKV_CUDA_ON"] == "1"
|
and os.environ["RWKV_CUDA_ON"] == "1"
|
||||||
@ -1327,6 +1626,8 @@ class RWKV(MyModule):
|
|||||||
):
|
):
|
||||||
ATT = self.cuda_att_one_fp16
|
ATT = self.cuda_att_one_fp16
|
||||||
FFN = self.cuda_ffn_one_fp16
|
FFN = self.cuda_ffn_one_fp16
|
||||||
|
if self.version == 5:
|
||||||
|
ATT = self.cuda_att_one_v5_fp16
|
||||||
|
|
||||||
x = x.to(dtype=atype, device=dev)
|
x = x.to(dtype=atype, device=dev)
|
||||||
|
|
||||||
@ -1355,6 +1656,7 @@ class RWKV(MyModule):
|
|||||||
orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x
|
orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x
|
||||||
omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x
|
omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x
|
||||||
ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x
|
ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x
|
||||||
|
if self.version == 4:
|
||||||
(
|
(
|
||||||
x,
|
x,
|
||||||
state[i * 5 + 0],
|
state[i * 5 + 0],
|
||||||
@ -1395,6 +1697,41 @@ class RWKV(MyModule):
|
|||||||
omy,
|
omy,
|
||||||
ory,
|
ory,
|
||||||
)
|
)
|
||||||
|
elif self.version == 5:
|
||||||
|
x, state[i * 3 + 0], state[i * 3 + 1] = ATT(
|
||||||
|
x,
|
||||||
|
state[i * 3 + 0],
|
||||||
|
state[i * 3 + 1],
|
||||||
|
w[f"{bbb}ln1.weight"],
|
||||||
|
w[f"{bbb}ln1.bias"],
|
||||||
|
w[f"{att}ln_x.weight"],
|
||||||
|
w[f"{att}ln_x.bias"],
|
||||||
|
w[f"{att}time_mix_k"],
|
||||||
|
w[f"{att}time_mix_v"],
|
||||||
|
w[f"{att}time_mix_r"],
|
||||||
|
w[f"{att}time_decay"],
|
||||||
|
w[f"{att}time_first"],
|
||||||
|
kw,
|
||||||
|
vw,
|
||||||
|
rw,
|
||||||
|
ow,
|
||||||
|
kmx,
|
||||||
|
krx,
|
||||||
|
kmy,
|
||||||
|
kry,
|
||||||
|
vmx,
|
||||||
|
vrx,
|
||||||
|
vmy,
|
||||||
|
vry,
|
||||||
|
rmx,
|
||||||
|
rrx,
|
||||||
|
rmy,
|
||||||
|
rry,
|
||||||
|
omx,
|
||||||
|
orx,
|
||||||
|
omy,
|
||||||
|
ory,
|
||||||
|
)
|
||||||
if dd.stream:
|
if dd.stream:
|
||||||
del kw, vw, rw, ow
|
del kw, vw, rw, ow
|
||||||
|
|
||||||
@ -1417,9 +1754,13 @@ class RWKV(MyModule):
|
|||||||
rrx = w[f"{ffn}receptance.weight_rx"] if wtype == torch.uint8 else x
|
rrx = w[f"{ffn}receptance.weight_rx"] if wtype == torch.uint8 else x
|
||||||
rmy = w[f"{ffn}receptance.weight_my"] if wtype == torch.uint8 else x
|
rmy = w[f"{ffn}receptance.weight_my"] if wtype == torch.uint8 else x
|
||||||
rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x
|
rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x
|
||||||
x, state[i * 5 + 4] = FFN(
|
if self.version == 4:
|
||||||
|
offset = i * 5 + 4
|
||||||
|
elif self.version == 5:
|
||||||
|
offset = i * 3 + 2
|
||||||
|
x, state[offset] = FFN(
|
||||||
x,
|
x,
|
||||||
state[i * 5 + 4],
|
state[offset],
|
||||||
w[f"{bbb}ln2.weight"],
|
w[f"{bbb}ln2.weight"],
|
||||||
w[f"{bbb}ln2.bias"],
|
w[f"{bbb}ln2.bias"],
|
||||||
w[f"{ffn}time_mix_k"],
|
w[f"{ffn}time_mix_k"],
|
||||||
|
Loading…
Reference in New Issue
Block a user