add rwkv-cuda-beta support (faster)
This commit is contained in:
parent
da68926e9c
commit
8a13bd3c1e
@ -10,7 +10,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (a *App) StartServer(python string, port int, host string) (string, error) {
|
||||
func (a *App) StartServer(python string, port int, host string, rwkvBeta bool) (string, error) {
|
||||
var err error
|
||||
if python == "" {
|
||||
python, err = GetPython()
|
||||
@ -18,7 +18,12 @@ func (a *App) StartServer(python string, port int, host string) (string, error)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return Cmd(python, "./backend-python/main.py", strconv.Itoa(port), host)
|
||||
args := []string{python, "./backend-python/main.py"}
|
||||
if rwkvBeta {
|
||||
args = append(args, "--rwkv-beta")
|
||||
}
|
||||
args = append(args, "--port", strconv.Itoa(port), "--host", host)
|
||||
return Cmd(args...)
|
||||
}
|
||||
|
||||
func (a *App) ConvertModel(python string, modelPath string, strategy string, outPath string) (string, error) {
|
||||
|
@ -1,5 +1,6 @@
|
||||
from enum import Enum, auto
|
||||
|
||||
Args = "args"
|
||||
Model = "model"
|
||||
Model_Status = "model_status"
|
||||
Model_Config = "model_config"
|
||||
|
@ -1,5 +1,11 @@
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from typing import Sequence
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
@ -34,6 +40,11 @@ app.include_router(state_cache.router)
|
||||
@app.on_event("startup")
|
||||
def init():
|
||||
global_var.init()
|
||||
cmd_params = os.environ["RWKV_RUNNER_PARAMS"]
|
||||
global_var.set(
|
||||
global_var.Args, get_args(cmd_params.split(" ") if cmd_params else None)
|
||||
)
|
||||
|
||||
state_cache.init()
|
||||
|
||||
set_torch()
|
||||
@ -56,9 +67,34 @@ def exit():
|
||||
parent.kill()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
port=8000 if len(sys.argv) < 2 else int(sys.argv[1]),
|
||||
host="127.0.0.1" if len(sys.argv) < 3 else sys.argv[2],
|
||||
def get_args(args: Union[Sequence[str], None] = None):
|
||||
parser = argparse.ArgumentParser()
|
||||
group = parser.add_argument_group(title="server arguments")
|
||||
group.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="port to run the server on (default: 8000)",
|
||||
)
|
||||
group.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="host to run the server on (default: 127.0.0.1)",
|
||||
)
|
||||
group = parser.add_argument_group(title="mode arguments")
|
||||
group.add_argument(
|
||||
"--rwkv-beta",
|
||||
action="store_true",
|
||||
help="whether to use rwkv-beta (default: False)",
|
||||
)
|
||||
args = parser.parse_args(args)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
os.environ["RWKV_RUNNER_PARAMS"] = " ".join(sys.argv[1:])
|
||||
print("--- %s seconds ---" % (time.time() - start_time))
|
||||
uvicorn.run("main:app", port=args.port, host=args.host, workers=1)
|
||||
|
124
backend-python/rwkv_pip/beta/cuda/att_one.cu
vendored
Normal file
124
backend-python/rwkv_pip/beta/cuda/att_one.cu
vendored
Normal file
@ -0,0 +1,124 @@
|
||||
#include "ATen/ATen.h"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "element_wise.h"
|
||||
#include "util.h"
|
||||
|
||||
// Equivalent Python code:
|
||||
// ww = t_first + k
|
||||
// p = torch.maximum(pp, ww)
|
||||
// e1 = torch.exp(pp - p)
|
||||
// e2 = torch.exp(ww - p)
|
||||
// wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
|
||||
// ww = t_decay + pp
|
||||
// p = torch.maximum(ww, k)
|
||||
// e1 = torch.exp(ww - p)
|
||||
// e2 = torch.exp(k - p)
|
||||
// t1 = e1 * aa + e2 * v
|
||||
// t2 = e1 * bb + e2
|
||||
// r = r * wkv
|
||||
// return t1, t2, p, r
|
||||
struct WkvForwardOne {
|
||||
const float *t_first;
|
||||
const float *k;
|
||||
const float *pp;
|
||||
const float *aa;
|
||||
const float *bb;
|
||||
const float *t_decay;
|
||||
const float *v;
|
||||
/* out */ float *t1;
|
||||
/* out */ float *t2;
|
||||
/* out */ float *p;
|
||||
/* in & out */ half *r;
|
||||
|
||||
__device__ void operator()(int i) const {
|
||||
float ww = t_first[i] + k[i];
|
||||
float pp_ = pp[i];
|
||||
float p_ = (pp_ > ww) ? pp_ : ww;
|
||||
float e1 = expf(pp_ - p_);
|
||||
float e2 = expf(ww - p_);
|
||||
float aa_ = aa[i];
|
||||
float bb_ = bb[i];
|
||||
float v_ = v[i];
|
||||
r[i] = __hmul(r[i], __float2half(((e1 * aa_ + e2 * v_) / (e1 * bb_ + e2))));
|
||||
ww = t_decay[i] + pp_;
|
||||
float k_ = k[i];
|
||||
p_ = (ww > k_) ? ww : k_;
|
||||
e1 = expf(ww - p_);
|
||||
e2 = expf(k_ - p_);
|
||||
t1[i] = e1 * aa_ + e2 * v_;
|
||||
t2[i] = e1 * bb_ + e2;
|
||||
p[i] = p_;
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
Equivalent Python code:
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
*/
|
||||
|
||||
struct Mix {
|
||||
const half *xx;
|
||||
const half *sx;
|
||||
const half *k_mix;
|
||||
const half *v_mix;
|
||||
const half *r_mix;
|
||||
/* out */ half *kx;
|
||||
/* out */ half *vx;
|
||||
/* out */ half *rx;
|
||||
|
||||
__device__ void operator()(int i) const {
|
||||
half xx_ = xx[i];
|
||||
half sx_ = sx[i];
|
||||
half k_mix_ = k_mix[i];
|
||||
half v_mix_ = v_mix[i];
|
||||
half r_mix_ = r_mix[i];
|
||||
kx[i] = __hadd(__hmul(xx_, k_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
|
||||
vx[i] = __hadd(__hmul(xx_, v_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), v_mix_)));
|
||||
rx[i] = __hadd(__hmul(xx_, r_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
|
||||
}
|
||||
};
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
||||
|
||||
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||
Tensor v_mix, Tensor r_mix, Tensor kw,
|
||||
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, Tensor rw,
|
||||
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
|
||||
/* imm */ Tensor k, Tensor pp, Tensor ww, Tensor aa, Tensor bb,
|
||||
Tensor t_decay, /* imm */ Tensor v, /* in & out */ Tensor r,
|
||||
/* out */ Tensor x_plus_out, /* out */ Tensor t1,
|
||||
/* out */ Tensor t2, /* out */ Tensor p) {
|
||||
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
|
||||
element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
|
||||
data_ptr<half>(k_mix), data_ptr<half>(v_mix),
|
||||
data_ptr<half>(r_mix), data_ptr<half>(kx),
|
||||
data_ptr<half>(vx), data_ptr<half>(rx)},
|
||||
x.numel());
|
||||
|
||||
gemm_fp16_cublas(kx, kw, k);
|
||||
gemm_fp16_cublas(vx, vw, v);
|
||||
gemm_fp16_cublas(rx, rw, r);
|
||||
at::sigmoid_(r);
|
||||
|
||||
element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
|
||||
data_ptr<float>(pp), data_ptr<float>(aa),
|
||||
data_ptr<float>(bb), data_ptr<float>(t_decay),
|
||||
data_ptr<float>(v), data_ptr<float>(t1),
|
||||
data_ptr<float>(t2), data_ptr<float>(p),
|
||||
data_ptr<half>(r)},
|
||||
x.numel());
|
||||
|
||||
gemm_fp16_cublas(r, ow, x_plus_out);
|
||||
x_plus_out += x;
|
||||
return xx;
|
||||
}
|
179
backend-python/rwkv_pip/beta/cuda/att_seq.cu
vendored
Normal file
179
backend-python/rwkv_pip/beta/cuda/att_seq.cu
vendored
Normal file
@ -0,0 +1,179 @@
|
||||
#include "ATen/ATen.h"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "util.h"
|
||||
#include "element_wise.h"
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
||||
void gemm_fp16_cublas(const void *a, const void *b, void *c, int m,
|
||||
int n, int k, bool output_fp32);
|
||||
|
||||
// based on `kernel_wkv_forward`, fusing more operations
|
||||
__global__ void kernel_wkv_forward_new(
|
||||
const int B, const int T, const int C, const float *__restrict__ const _w,
|
||||
const float *__restrict__ const _u, const float *__restrict__ const _k,
|
||||
const float *__restrict__ const _v, const half *__restrict__ const r,
|
||||
half *__restrict__ const _y, float *__restrict__ const _aa,
|
||||
float *__restrict__ const _bb, float *__restrict__ const _pp) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
const int _state_offset = _b * C + _c;
|
||||
|
||||
float u = _u[_c];
|
||||
float w = _w[_c];
|
||||
const float *__restrict__ const k = _k + _offset;
|
||||
const float *__restrict__ const v = _v + _offset;
|
||||
half *__restrict__ const y = _y + _offset;
|
||||
|
||||
float aa = _aa[_state_offset];
|
||||
float bb = _bb[_state_offset];
|
||||
float pp = _pp[_state_offset];
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const float kk = k[ii];
|
||||
const float vv = v[ii];
|
||||
float ww = u + kk;
|
||||
float p = max(pp, ww);
|
||||
float e1 = exp(pp - p);
|
||||
float e2 = exp(ww - p);
|
||||
y[ii] = __float2half((e1 * aa + e2 * vv) / (e1 * bb + e2));
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
_aa[_state_offset] = aa;
|
||||
_bb[_state_offset] = bb;
|
||||
_pp[_state_offset] = pp;
|
||||
}
|
||||
|
||||
void cuda_wkv_forward_new(int B, int T, int C, float *w, float *u, float *k,
|
||||
float *v, half *r, half *y, float *aa, float *bb,
|
||||
float *pp) {
|
||||
dim3 threadsPerBlock(min(C, 32));
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_wkv_forward_new<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, r,
|
||||
y, aa, bb, pp);
|
||||
}
|
||||
|
||||
__global__ void _att_mix(const half *xx, const half *sx, const half *k_mix,
|
||||
const half *v_mix, const half *r_mix,
|
||||
const int outer_size, const int inner_size, half *kx,
|
||||
half *vx, half *rx) {
|
||||
for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size;
|
||||
idx2 += blockDim.x * gridDim.x) {
|
||||
half k_mix_ = k_mix[idx2];
|
||||
half v_mix_ = v_mix[idx2];
|
||||
half r_mix_ = r_mix[idx2];
|
||||
for (int row = 0; row < outer_size; ++row) {
|
||||
int idx1 = row * inner_size + idx2;
|
||||
half xx_ = xx[idx1];
|
||||
half sx_ = sx[idx1];
|
||||
kx[idx1] = __hadd(__hmul(xx_, k_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
|
||||
vx[idx1] = __hadd(__hmul(xx_, v_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), v_mix_)));
|
||||
rx[idx1] = __hadd(__hmul(xx_, r_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void att_mix(const half *xx, const half *sx, const half *k_mix,
|
||||
const half *v_mix, const half *r_mix, const int outer_size,
|
||||
const int inner_size, half *kx, half *vx, half *rx) {
|
||||
// 256 is good enough on most GPUs
|
||||
const int32_t BLOCK_SIZE = 256;
|
||||
assert(inner_size % BLOCK_SIZE == 0);
|
||||
_att_mix<<<inner_size / BLOCK_SIZE, BLOCK_SIZE>>>(
|
||||
xx, sx, k_mix, v_mix, r_mix, outer_size, inner_size, kx, vx, rx);
|
||||
}
|
||||
|
||||
struct InplaceSigmoid {
|
||||
__device__ __forceinline__ half operator()(int i) const {
|
||||
ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i]))));
|
||||
}
|
||||
half *ptr;
|
||||
};
|
||||
|
||||
struct InplaceMul {
|
||||
__device__ __forceinline__ half operator()(int i) const {
|
||||
y[i] = __hmul(x[i], y[i]);
|
||||
}
|
||||
half *y;
|
||||
half *x;
|
||||
};
|
||||
|
||||
/*
|
||||
Equivalent Python code:
|
||||
|
||||
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
|
||||
sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
vx = xx * v_mix + sx * (1 - v_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(gemm(rx, rw))
|
||||
k = gemm(kx, kw, output_dtype=torch.float32)
|
||||
v = gemm(vx, vw, output_dtype=torch.float32)
|
||||
|
||||
T = x.shape[0]
|
||||
for t in range(T):
|
||||
kk = k[t]
|
||||
vv = v[t]
|
||||
ww = t_first + kk
|
||||
p = torch.maximum(pp, ww)
|
||||
e1 = torch.exp(pp - p)
|
||||
e2 = torch.exp(ww - p)
|
||||
sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
|
||||
ww = t_decay + pp
|
||||
p = torch.maximum(ww, kk)
|
||||
e1 = torch.exp(ww - p)
|
||||
e2 = torch.exp(kk - p)
|
||||
aa = e1 * aa + e2 * vv
|
||||
bb = e1 * bb + e2
|
||||
pp = p
|
||||
out = gemm(r * sx, ow)
|
||||
return x + out, xx[-1,:], aa, bb, pp
|
||||
*/
|
||||
Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
||||
Tensor v_mix, Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
|
||||
Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
|
||||
Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out) {
|
||||
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
|
||||
sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0);
|
||||
char* buf_ptr = (char*)buf.data_ptr();
|
||||
half* kx = (half*)buf_ptr;
|
||||
half* vx = kx + x.numel();
|
||||
half* rx = vx + x.numel();
|
||||
half* wkv_y = rx + x.numel();
|
||||
att_mix(data_ptr<half>(xx), data_ptr<half>(sx), data_ptr<half>(k_mix),
|
||||
data_ptr<half>(v_mix), data_ptr<half>(r_mix), xx.size(0), xx.size(1),
|
||||
kx, vx, rx);
|
||||
float* k = reinterpret_cast<float*>(wkv_y + x.numel());
|
||||
float* v = k + x.size(0) * kw.size(1);
|
||||
half* r = reinterpret_cast<half*>(v + x.size(0) * vw.size(1));
|
||||
|
||||
gemm_fp16_cublas(kx, kw.data_ptr(), k, x.size(0), kw.size(1), kw.size(0), true);
|
||||
gemm_fp16_cublas(vx, vw.data_ptr(), v, x.size(0), vw.size(1), vw.size(0), true);
|
||||
gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), rw.size(0), false);
|
||||
element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1));
|
||||
cuda_wkv_forward_new(1, x.size(0), x.size(1), data_ptr<float>(t_decay),
|
||||
data_ptr<float>(t_first), k, v, r,
|
||||
wkv_y, data_ptr<float>(aa),
|
||||
data_ptr<float>(bb), data_ptr<float>(pp));
|
||||
element_wise(InplaceMul{wkv_y, r}, x.numel());
|
||||
gemm_fp16_cublas(wkv_y, ow.data_ptr(), x_plus_out.data_ptr(), x.size(0), ow.size(1), ow.size(0), false);
|
||||
x_plus_out += x;
|
||||
return xx;
|
||||
}
|
21
backend-python/rwkv_pip/beta/cuda/element_wise.h
vendored
Normal file
21
backend-python/rwkv_pip/beta/cuda/element_wise.h
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
template <typename Func> __global__ void _element_wise(Func func, int n) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
|
||||
i += blockDim.x * gridDim.x) {
|
||||
func(i);
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: packed data type (e.g. float4) is a overkill for current sizes
|
||||
// (4096 in 7B model and 768 in 0.1B model),
|
||||
// and is not faster than the plain float version.
|
||||
template <typename Func>
|
||||
void element_wise(Func func, int n) {
|
||||
// 256 is good enough on most GPUs
|
||||
const int32_t BLOCK_SIZE = 256;
|
||||
assert(n % BLOCK_SIZE == 0);
|
||||
_element_wise<<<n / BLOCK_SIZE, BLOCK_SIZE>>>(func, n);
|
||||
}
|
165
backend-python/rwkv_pip/beta/cuda/ffn.cu
vendored
Normal file
165
backend-python/rwkv_pip/beta/cuda/ffn.cu
vendored
Normal file
@ -0,0 +1,165 @@
|
||||
#include "ATen/ATen.h"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "element_wise.h"
|
||||
#include "util.h"
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m,
|
||||
int ori_n, int ori_k, bool output_fp32);
|
||||
|
||||
__global__ void _ffn_seq_mix(const half *xx, const half *sx, const half *k_mix,
|
||||
const half *r_mix, const int outer_size,
|
||||
const int inner_size, half *kx, half *rx) {
|
||||
for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size;
|
||||
idx2 += blockDim.x * gridDim.x) {
|
||||
half k_mix_ = k_mix[idx2];
|
||||
half r_mix_ = r_mix[idx2];
|
||||
for (int row = 0; row < outer_size; ++row) {
|
||||
int idx1 = row * inner_size + idx2;
|
||||
half xx_ = xx[idx1];
|
||||
half sx_ = sx[idx1];
|
||||
kx[idx1] = __hadd(__hmul(xx_, k_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
|
||||
rx[idx1] = __hadd(__hmul(xx_, r_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ffn_seq_mix(const half *xx, const half *sx, const half *k_mix,
|
||||
const half *r_mix, const int outer_size, const int inner_size,
|
||||
half *kx, half *rx) {
|
||||
// 256 is good enough on most GPUs
|
||||
const int32_t BLOCK_SIZE = 256;
|
||||
assert(inner_size % BLOCK_SIZE == 0);
|
||||
_ffn_seq_mix<<<inner_size / BLOCK_SIZE, BLOCK_SIZE>>>(
|
||||
xx, sx, k_mix, r_mix, outer_size, inner_size, kx, rx);
|
||||
}
|
||||
|
||||
struct InplaceSigmoid {
|
||||
__device__ __forceinline__ void operator()(int i) const {
|
||||
ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i]))));
|
||||
}
|
||||
half *ptr;
|
||||
};
|
||||
|
||||
struct InplaceReLUAndSquare {
|
||||
__device__ __forceinline__ void operator()(int i) const {
|
||||
// __hmax is not defined in old cuda
|
||||
if (__hgt(ptr[i], __float2half(0))) {
|
||||
ptr[i] = __hmul(ptr[i], ptr[i]);
|
||||
} else {
|
||||
ptr[i] = __float2half(0);
|
||||
}
|
||||
}
|
||||
half *ptr;
|
||||
};
|
||||
|
||||
struct InplaceFma {
|
||||
__device__ __forceinline__ void operator()(int i) const {
|
||||
a[i] = __hfma(a[i], b[i], c[i]);
|
||||
}
|
||||
half *a;
|
||||
const half *b;
|
||||
const half *c;
|
||||
};
|
||||
|
||||
/*
|
||||
Equivalent Python code:
|
||||
|
||||
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
|
||||
sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(gemm(rx, rw))
|
||||
vx = torch.square(torch.relu(gemm(kx, kw)))
|
||||
out = r * gemm(vx, vw)
|
||||
return x + out, xx[-1,:]
|
||||
*/
|
||||
Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
||||
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
|
||||
/* imm */ Tensor buf,
|
||||
/* out */ Tensor x_plus_out) {
|
||||
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
|
||||
sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0);
|
||||
char *buf_ptr = (char *)buf.data_ptr();
|
||||
half *kx = (half *)buf_ptr;
|
||||
half *rx = kx + x.numel();
|
||||
half *vx = rx + x.numel();
|
||||
half *r = vx + x.size(0) * kw.size(1);
|
||||
ffn_seq_mix(data_ptr<half>(xx), data_ptr<half>(sx), data_ptr<half>(k_mix),
|
||||
data_ptr<half>(r_mix), xx.size(0), xx.size(1), kx, rx);
|
||||
|
||||
gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), x.size(1),
|
||||
false);
|
||||
element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1));
|
||||
gemm_fp16_cublas(kx, kw.data_ptr(), vx, x.size(0), kw.size(1), x.size(1),
|
||||
false);
|
||||
element_wise(InplaceReLUAndSquare{vx}, x.size(0) * kw.size(1));
|
||||
gemm_fp16_cublas(vx, vw.data_ptr(), x_plus_out.data_ptr(), x.size(0),
|
||||
vw.size(1), vw.size(0), false);
|
||||
element_wise(InplaceFma{data_ptr<half>(x_plus_out), r, data_ptr<half>(x)},
|
||||
x_plus_out.numel());
|
||||
return xx;
|
||||
}
|
||||
|
||||
struct FfnOneMix {
|
||||
__device__ __forceinline__ void operator()(int idx) {
|
||||
half k_mix_ = k_mix[idx];
|
||||
half r_mix_ = r_mix[idx];
|
||||
half xx_ = xx[idx];
|
||||
half sx_ = sx[idx];
|
||||
kx[idx] = __hadd(__hmul(xx_, k_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
|
||||
rx[idx] = __hadd(__hmul(xx_, r_mix_),
|
||||
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
|
||||
}
|
||||
half *k_mix;
|
||||
half *r_mix;
|
||||
half *xx;
|
||||
half *sx;
|
||||
half *kx;
|
||||
half *rx;
|
||||
};
|
||||
|
||||
/*
|
||||
Equivalent Python code:
|
||||
|
||||
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
|
||||
kx = xx * k_mix + sx * (1 - k_mix)
|
||||
rx = xx * r_mix + sx * (1 - r_mix)
|
||||
|
||||
r = torch.sigmoid(gemm(rx, rw))
|
||||
vx = torch.square(torch.relu(gemm(kx, kw)))
|
||||
out = r * gemm(vx, vw)
|
||||
return x + out, xx
|
||||
*/
|
||||
Tensor ffn_one(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
||||
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
|
||||
/* imm */ Tensor buf,
|
||||
/* out */ Tensor x_plus_out) {
|
||||
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
|
||||
char *buf_ptr = (char *)buf.data_ptr();
|
||||
half *kx = (half *)buf_ptr;
|
||||
half *rx = kx + x.numel();
|
||||
half *vx = rx + x.numel();
|
||||
half *r = vx + x.size(0) * kw.size(1);
|
||||
element_wise(FfnOneMix{data_ptr<half>(k_mix), data_ptr<half>(r_mix),
|
||||
data_ptr<half>(xx), data_ptr<half>(sx), kx, rx},
|
||||
x.numel());
|
||||
// vector * matrix, so m = 1
|
||||
gemm_fp16_cublas(rx, rw.data_ptr(), r, 1, rw.size(1), rw.size(0), false);
|
||||
element_wise(InplaceSigmoid{r}, rw.size(1));
|
||||
gemm_fp16_cublas(kx, kw.data_ptr(), vx, 1, kw.size(1), kw.size(0), false);
|
||||
element_wise(InplaceReLUAndSquare{vx}, kw.size(1));
|
||||
gemm_fp16_cublas(vx, vw.data_ptr(), x_plus_out.data_ptr(), 1, vw.size(1),
|
||||
vw.size(0), false);
|
||||
element_wise(InplaceFma{data_ptr<half>(x_plus_out), r, data_ptr<half>(x)},
|
||||
x_plus_out.numel());
|
||||
return xx;
|
||||
}
|
80
backend-python/rwkv_pip/beta/cuda/gemm_fp16_cublas.cpp
vendored
Normal file
80
backend-python/rwkv_pip/beta/cuda/gemm_fp16_cublas.cpp
vendored
Normal file
@ -0,0 +1,80 @@
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define CUBLAS_CHECK(condition) \
|
||||
for (cublasStatus_t _cublas_check_status = (condition); \
|
||||
_cublas_check_status != CUBLAS_STATUS_SUCCESS;) \
|
||||
throw std::runtime_error("cuBLAS error " + \
|
||||
std::to_string(_cublas_check_status) + " at " + \
|
||||
std::to_string(__LINE__));
|
||||
|
||||
#define CUDA_CHECK(condition) \
|
||||
for (cudaError_t _cuda_check_status = (condition); \
|
||||
_cuda_check_status != cudaSuccess;) \
|
||||
throw std::runtime_error( \
|
||||
"CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
|
||||
" at " + std::to_string(__LINE__));
|
||||
|
||||
cublasHandle_t get_cublas_handle() {
|
||||
static cublasHandle_t cublas_handle = []() {
|
||||
cublasHandle_t handle = nullptr;
|
||||
CUBLAS_CHECK(cublasCreate(&handle));
|
||||
#if CUDA_VERSION < 11000
|
||||
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
|
||||
#else
|
||||
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
|
||||
#endif // CUDA_VERSION < 11000
|
||||
return handle;
|
||||
}();
|
||||
return cublas_handle;
|
||||
}
|
||||
|
||||
/*
|
||||
NOTE: blas gemm is column-major by default, but we need row-major output.
|
||||
The data of row-major, transposed matrix is exactly the same as the
|
||||
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
|
||||
*/
|
||||
void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m,
|
||||
int ori_n, int ori_k, bool output_fp32) {
|
||||
const auto cuda_data_type = CUDA_R_16F;
|
||||
const auto cuda_c_data_type = output_fp32 ? CUDA_R_32F : CUDA_R_16F;
|
||||
const auto compute_type = CUDA_R_32F;
|
||||
const float sp_alpha = 1.f;
|
||||
// use CUBLAS_OP_N. see the notes above
|
||||
const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
|
||||
const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
|
||||
// m = (B^T).size(0) = B.size(1) = n;
|
||||
const int cublas_m = ori_n;
|
||||
const int cublas_k = ori_k;
|
||||
// comptiable with rwkv one mode, where 1-D tensor * 2-D tensor
|
||||
// const int n = a.dense_dim() == 1 ? 1 : a.size(0);
|
||||
const int cublas_n = ori_m;
|
||||
const int cublas_lda = cublas_m;
|
||||
const int cublas_ldb = cublas_k;
|
||||
const int cublas_ldc = cublas_m;
|
||||
cublasHandle_t cublas_handle = get_cublas_handle();
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
|
||||
#else
|
||||
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
|
||||
#endif
|
||||
const float sp_beta = 0.f;
|
||||
CUBLAS_CHECK(cublasGemmEx(
|
||||
cublas_handle, cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
|
||||
cublas_k, &sp_alpha, b, cuda_data_type, cublas_lda,
|
||||
a, cuda_data_type, cublas_ldb, &sp_beta, c,
|
||||
cuda_c_data_type, cublas_ldc, compute_type, algo));
|
||||
}
|
||||
|
||||
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
|
||||
// comptiable with rwkv one mode, 1-D tensor * 2-D tensor
|
||||
const int m = a.dense_dim() == 1 ? 1 : a.size(0);
|
||||
const int n = b.size(1);
|
||||
const int k = b.size(0);
|
||||
gemm_fp16_cublas(a.data_ptr(), b.data_ptr(), c.data_ptr(), m, n, k,
|
||||
c.dtype() == torch::kFloat32);
|
||||
}
|
246
backend-python/rwkv_pip/beta/cuda/operators.cu
vendored
Normal file
246
backend-python/rwkv_pip/beta/cuda/operators.cu
vendored
Normal file
@ -0,0 +1,246 @@
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
#include "ATen/ATen.h"
|
||||
#include <cuda_fp16.h>
|
||||
#define MIN_VALUE (-1e38)
|
||||
typedef at::Half fp16;
|
||||
__half *cast(fp16 *ptr) {
|
||||
return reinterpret_cast<__half *>(ptr);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
__global__ void kernel_wkv_forward(const int B, const int T, const int C,
|
||||
const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
|
||||
F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
const int _state_offset = _b * C + _c;
|
||||
|
||||
float u = _u[_c];
|
||||
float w = _w[_c];
|
||||
const F *__restrict__ const k = _k + _offset;
|
||||
const F *__restrict__ const v = _v + _offset;
|
||||
F *__restrict__ const y = _y + _offset;
|
||||
|
||||
float aa = _aa[_state_offset];
|
||||
float bb = _bb[_state_offset];
|
||||
float pp = _pp[_state_offset];
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const float kk = float(k[ii]);
|
||||
const float vv = float(v[ii]);
|
||||
float ww = u + kk;
|
||||
float p = max(pp, ww);
|
||||
float e1 = exp(pp - p);
|
||||
float e2 = exp(ww - p);
|
||||
y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
_aa[_state_offset] = aa;
|
||||
_bb[_state_offset] = bb;
|
||||
_pp[_state_offset] = pp;
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
|
||||
dim3 threadsPerBlock( min(C, 32) );
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
|
||||
}
|
||||
|
||||
template void cuda_wkv_forward<fp16>(
|
||||
int B, int T, int C,
|
||||
float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
|
||||
float *aa, float *bb, float *pp);
|
||||
template void cuda_wkv_forward<float>(
|
||||
int B, int T, int C,
|
||||
float *w, float *u, float *k, float *v, float *y,
|
||||
float *aa, float *bb, float *pp);
|
||||
|
||||
__global__ void kernel_mm_seq_fp32i8(
|
||||
const int B, const int N, const int M,
|
||||
const float *__restrict__ const x, const int x_stride,
|
||||
const uint8_t *__restrict__ const w, const int w_stride,
|
||||
const float *__restrict__ const mx,
|
||||
const float *__restrict__ const rx,
|
||||
const float *__restrict__ const my,
|
||||
const float *__restrict__ const ry,
|
||||
float *__restrict__ const y, const int y_stride) {
|
||||
|
||||
const int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int k = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
if (i < B && k < M) {
|
||||
float y_local = 0;
|
||||
for (int j = 0; j < N; ++j) {
|
||||
y_local += x[i * x_stride + j] * (
|
||||
(float(w[j * w_stride + k]) + 0.5f)
|
||||
* rx[k] * ry[j] + mx[k] + my[j]
|
||||
);
|
||||
}
|
||||
y[i * y_stride + k] = y_local;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void cuda_mm8_seq(int B, int N, int M,
|
||||
F *x, int x_stride,
|
||||
uint8_t *w, int w_stride,
|
||||
F *mx, F *rx,
|
||||
F *my, F *ry,
|
||||
F *y, int y_stride);
|
||||
|
||||
template <>
|
||||
void cuda_mm8_seq<float>(int B, int N, int M,
|
||||
float *x, int x_stride,
|
||||
uint8_t *w, int w_stride,
|
||||
float *mx, float *rx,
|
||||
float *my, float *ry,
|
||||
float *y, int y_stride) {
|
||||
dim3 blockSize(1, 128);
|
||||
dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
|
||||
kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
|
||||
B, N, M, x, x_stride, w, w_stride,
|
||||
mx, rx, my, ry, y, y_stride);
|
||||
}
|
||||
|
||||
__global__ void kernel_mm_seq_fp16i8(
|
||||
const int B, const int N, const int M,
|
||||
const __half *__restrict__ const x, const int x_stride,
|
||||
const uint8_t *__restrict__ const w, const int w_stride,
|
||||
const __half *__restrict__ const mx,
|
||||
const __half *__restrict__ const rx,
|
||||
const __half *__restrict__ const my,
|
||||
const __half *__restrict__ const ry,
|
||||
__half *__restrict__ const y, const int y_stride) {
|
||||
|
||||
const int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int k = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
if (i < B && k < M) {
|
||||
float y_local = 0;
|
||||
for (int j = 0; j < N; ++j) {
|
||||
y_local += __half2float(x[i * x_stride + j]) * (
|
||||
(float(w[j * w_stride + k]) + 0.5f)
|
||||
* __half2float(rx[k]) * __half2float(ry[j])
|
||||
+ __half2float(mx[k]) + __half2float(my[j])
|
||||
);
|
||||
}
|
||||
y[i * y_stride + k] = __float2half(y_local);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void cuda_mm8_seq<fp16>(int B, int N, int M,
|
||||
fp16 *x, int x_stride,
|
||||
uint8_t *w, int w_stride,
|
||||
fp16 *mx, fp16 *rx,
|
||||
fp16 *my, fp16 *ry,
|
||||
fp16 *y, int y_stride) {
|
||||
dim3 blockSize(1, 128);
|
||||
dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
|
||||
kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
|
||||
B, N, M, cast(x), x_stride, w, w_stride,
|
||||
cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
|
||||
}
|
||||
|
||||
#define MM8_ONE_JSPLIT 24
|
||||
#define MM8_ONE_TILE 1024
|
||||
|
||||
__global__ void kernel_mm_one_fp32i8(
|
||||
const int N, const int M,
|
||||
const float *__restrict__ const x,
|
||||
const uint8_t *__restrict__ const w, const int w_stride,
|
||||
const float *__restrict__ const mx,
|
||||
const float *__restrict__ const rx,
|
||||
const float *__restrict__ const my,
|
||||
const float *__restrict__ const ry,
|
||||
float *__restrict__ const y) {
|
||||
|
||||
const int k = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
|
||||
const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
|
||||
|
||||
if (k < M) {
|
||||
float y_local = 0;
|
||||
for (int j = j0; j < j1; ++j) {
|
||||
y_local += x[j] * (
|
||||
(float(w[j * w_stride + k]) + 0.5f)
|
||||
* rx[k] * ry[j] + mx[k] + my[j]
|
||||
);
|
||||
}
|
||||
atomicAdd(&y[k], y_local);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void cuda_mm8_one(int N, int M,
|
||||
F *x,
|
||||
uint8_t *w, int w_stride,
|
||||
F *mx, F *rx,
|
||||
F *my, F *ry,
|
||||
float *y);
|
||||
|
||||
template <>
|
||||
void cuda_mm8_one<float>(int N, int M,
|
||||
float *x,
|
||||
uint8_t *w, int w_stride,
|
||||
float *mx, float *rx,
|
||||
float *my, float *ry,
|
||||
float *y) {
|
||||
dim3 blockSize(1, MM8_ONE_TILE);
|
||||
dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
|
||||
kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
|
||||
N, M, x, w, w_stride,
|
||||
mx, rx, my, ry, y);
|
||||
}
|
||||
|
||||
__global__ void kernel_mm_one_fp16i8(
|
||||
const int N, const int M,
|
||||
const __half *__restrict__ const x,
|
||||
const uint8_t *__restrict__ const w, const int w_stride,
|
||||
const __half *__restrict__ const mx,
|
||||
const __half *__restrict__ const rx,
|
||||
const __half *__restrict__ const my,
|
||||
const __half *__restrict__ const ry,
|
||||
float *__restrict__ const y) {
|
||||
|
||||
const int k = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
|
||||
const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
|
||||
|
||||
if (k < M) {
|
||||
float y_local = 0;
|
||||
for (int j = j0; j < j1; ++j) {
|
||||
y_local += __half2float(x[j]) * (
|
||||
(float(w[j * w_stride + k]) + 0.5f)
|
||||
* __half2float(rx[k]) * __half2float(ry[j])
|
||||
+ __half2float(mx[k]) + __half2float(my[j])
|
||||
);
|
||||
}
|
||||
atomicAdd(&y[k], y_local);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void cuda_mm8_one<fp16>(int N, int M,
|
||||
fp16 *x,
|
||||
uint8_t *w, int w_stride,
|
||||
fp16 *mx, fp16 *rx,
|
||||
fp16 *my, fp16 *ry,
|
||||
float *y) {
|
||||
dim3 blockSize(1, MM8_ONE_TILE);
|
||||
dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
|
||||
kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
|
||||
N, M, cast(x), w, w_stride,
|
||||
cast(mx), cast(rx), cast(my), cast(ry), y);
|
||||
}
|
7
backend-python/rwkv_pip/beta/cuda/util.h
vendored
Normal file
7
backend-python/rwkv_pip/beta/cuda/util.h
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
#include "ATen/ATen.h"
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
template <typename T> T *data_ptr(torch::Tensor x) { return x.data_ptr<T>(); }
|
||||
template <> inline half *data_ptr(torch::Tensor x) {
|
||||
return reinterpret_cast<half *>(x.data_ptr<at::Half>());
|
||||
}
|
167
backend-python/rwkv_pip/beta/cuda/wrapper.cpp
vendored
Normal file
167
backend-python/rwkv_pip/beta/cuda/wrapper.cpp
vendored
Normal file
@ -0,0 +1,167 @@
|
||||
#include <torch/extension.h>
|
||||
#include "ATen/ATen.h"
|
||||
#include <iostream>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
typedef at::Half fp16;
|
||||
|
||||
template <typename F>
|
||||
void cuda_wkv_forward(int B, int T, int C,
|
||||
float *w, float *u, F *k, F *v, F *y,
|
||||
float *aa, float *bb, float *pp);
|
||||
template <typename F>
|
||||
void cuda_mm8_seq(int B, int N, int M,
|
||||
F *x, int x_stride,
|
||||
uint8_t *w, int w_stride,
|
||||
F *mx, F *rx,
|
||||
F *my, F *ry,
|
||||
F *y, int y_stride);
|
||||
template <typename F>
|
||||
void cuda_mm8_one(int N, int M,
|
||||
F *x,
|
||||
uint8_t *w, int w_stride,
|
||||
F *mx, F *rx,
|
||||
F *my, F *ry,
|
||||
float *y);
|
||||
|
||||
void wkv_forward(int64_t B, int64_t T, int64_t C,
|
||||
torch::Tensor &w, torch::Tensor &u,
|
||||
torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
|
||||
torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
|
||||
switch (k.scalar_type()) {
|
||||
case c10::ScalarType::Half:
|
||||
cuda_wkv_forward(B, T, C,
|
||||
w.data_ptr<float>(), u.data_ptr<float>(),
|
||||
k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
|
||||
aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
|
||||
break;
|
||||
case c10::ScalarType::Float:
|
||||
cuda_wkv_forward(B, T, C,
|
||||
w.data_ptr<float>(), u.data_ptr<float>(),
|
||||
k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
|
||||
aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
|
||||
break;
|
||||
default:
|
||||
assert(false && "Only FP16 and FP32 are currently supported");
|
||||
}
|
||||
}
|
||||
|
||||
void mm8_seq(int64_t B, int64_t N, int64_t M,
|
||||
torch::Tensor &x, torch::Tensor &w,
|
||||
torch::Tensor &mx, torch::Tensor &rx,
|
||||
torch::Tensor &my, torch::Tensor &ry,
|
||||
torch::Tensor &y) {
|
||||
assert(x.stride(1) == 1);
|
||||
assert(w.stride(1) == 1);
|
||||
assert(mx.stride(0) == 1 && rx.stride(0) == 1);
|
||||
assert(my.stride(0) == 1 && ry.stride(0) == 1);
|
||||
assert(y.stride(1) == 1);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
|
||||
switch (x.scalar_type()) {
|
||||
case c10::ScalarType::Half:
|
||||
cuda_mm8_seq(
|
||||
B, N, M,
|
||||
x.data_ptr<fp16>(), x.stride(0),
|
||||
w.data_ptr<uint8_t>(), w.stride(0),
|
||||
mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
|
||||
my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
|
||||
y.data_ptr<fp16>(), y.stride(0));
|
||||
break;
|
||||
case c10::ScalarType::Float:
|
||||
cuda_mm8_seq(
|
||||
B, N, M,
|
||||
x.data_ptr<float>(), x.stride(0),
|
||||
w.data_ptr<uint8_t>(), w.stride(0),
|
||||
mx.data_ptr<float>(), rx.data_ptr<float>(),
|
||||
my.data_ptr<float>(), ry.data_ptr<float>(),
|
||||
y.data_ptr<float>(), y.stride(0));
|
||||
break;
|
||||
default:
|
||||
assert(false && "Only FP16 and FP32 are currently supported");
|
||||
}
|
||||
}
|
||||
void mm8_one(int64_t N, int64_t M,
|
||||
torch::Tensor &x, torch::Tensor &w,
|
||||
torch::Tensor &mx, torch::Tensor &rx,
|
||||
torch::Tensor &my, torch::Tensor &ry,
|
||||
torch::Tensor &y) {
|
||||
assert(x.stride(0) == 1);
|
||||
assert(w.stride(1) == 1);
|
||||
assert(mx.stride(0) == 1 && rx.stride(0) == 1);
|
||||
assert(my.stride(0) == 1 && ry.stride(0) == 1);
|
||||
assert(y.stride(0) == 1);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
|
||||
switch (x.scalar_type()) {
|
||||
case c10::ScalarType::Half:
|
||||
cuda_mm8_one(
|
||||
N, M,
|
||||
x.data_ptr<fp16>(),
|
||||
w.data_ptr<uint8_t>(), w.stride(0),
|
||||
mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
|
||||
my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
|
||||
y.data_ptr<float>());
|
||||
break;
|
||||
case c10::ScalarType::Float:
|
||||
cuda_mm8_one(
|
||||
N, M,
|
||||
x.data_ptr<float>(),
|
||||
w.data_ptr<uint8_t>(), w.stride(0),
|
||||
mx.data_ptr<float>(), rx.data_ptr<float>(),
|
||||
my.data_ptr<float>(), ry.data_ptr<float>(),
|
||||
y.data_ptr<float>());
|
||||
break;
|
||||
default:
|
||||
assert(false && "Only FP16 and FP32 are currently supported");
|
||||
}
|
||||
}
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
|
||||
|
||||
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
|
||||
Tensor v_mix, Tensor r_mix, Tensor kw,
|
||||
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, Tensor rw,
|
||||
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
|
||||
/* imm */ Tensor k, Tensor pp, Tensor ww, Tensor aa, Tensor bb,
|
||||
Tensor t_decay, /* imm */ Tensor v, /* in & out */ Tensor r,
|
||||
/* out */ Tensor x_plus_out, /* out */ Tensor t1,
|
||||
/* out */ Tensor t2, /* out */ Tensor p);
|
||||
|
||||
Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
||||
Tensor v_mix, Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
|
||||
Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
|
||||
Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out);
|
||||
|
||||
Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
||||
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
|
||||
/* imm */ Tensor buf,
|
||||
/* out */ Tensor x_plus_out);
|
||||
|
||||
Tensor ffn_one(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
|
||||
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
|
||||
/* imm */ Tensor buf,
|
||||
/* out */ Tensor x_plus_out);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("wkv_forward", &wkv_forward, "wkv forward");
|
||||
m.def("mm8_seq", &mm8_seq, "mm8 seq");
|
||||
m.def("mm8_one", &mm8_one, "mm8 one");
|
||||
m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
|
||||
m.def("att_one", &att_one, "att one");
|
||||
m.def("att_seq", &att_seq, "att seq");
|
||||
m.def("ffn_seq", &ffn_seq, "ffn seq");
|
||||
m.def("ffn_one", &ffn_one, "ffn one");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(rwkv, m) {
|
||||
m.def("wkv_forward", wkv_forward);
|
||||
m.def("mm8_seq", mm8_seq);
|
||||
m.def("mm8_one", mm8_one);
|
||||
m.def("gemm_fp16_cublas", gemm_fp16_cublas);
|
||||
m.def("att_one", att_one);
|
||||
m.def("att_seq", att_seq);
|
||||
m.def("ffn_seq", ffn_seq);
|
||||
m.def("ffn_one", ffn_one);
|
||||
}
|
1479
backend-python/rwkv_pip/beta/model.py
vendored
Normal file
1479
backend-python/rwkv_pip/beta/model.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
@ -10,6 +10,7 @@ from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
from routes import state_cache
|
||||
import global_var
|
||||
|
||||
|
||||
END_OF_TEXT = 0
|
||||
@ -27,7 +28,17 @@ class RWKVType(Enum):
|
||||
|
||||
class AbstractRWKV(ABC):
|
||||
def __init__(self, model: str, strategy: str, tokens_path: str):
|
||||
from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work
|
||||
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
|
||||
|
||||
# dynamic import to make RWKV_CUDA_ON work
|
||||
if rwkv_beta:
|
||||
from rwkv_pip.beta.model import (
|
||||
RWKV as Model,
|
||||
)
|
||||
else:
|
||||
from rwkv.model import (
|
||||
RWKV as Model,
|
||||
)
|
||||
from rwkv_pip.utils import PIPELINE
|
||||
|
||||
filename, _ = os.path.splitext(os.path.basename(model))
|
||||
@ -221,7 +232,7 @@ class AbstractRWKV(ABC):
|
||||
return state[0].tolist(), token_len
|
||||
|
||||
def generate(
|
||||
self, prompt: str, stop: Union[str, List[str]] = None
|
||||
self, prompt: str, stop: Union[str, List[str], None] = None
|
||||
) -> Iterable[Tuple[str, str, int, int]]:
|
||||
quick_log(None, None, "Generation Prompt:\n" + prompt)
|
||||
cache = None
|
||||
@ -438,8 +449,10 @@ The following is a coherent verbose detailed conversation between a girl named {
|
||||
{bot} usually gives {user} kind, helpful and informative advices.\n
|
||||
"""
|
||||
if self.rwkv_type == RWKVType.Raven
|
||||
else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
|
||||
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
|
||||
else (
|
||||
f"{user}{interface} hi\n\n{bot}{interface} Hi. "
|
||||
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
|
||||
)
|
||||
)
|
||||
logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
|
||||
try:
|
||||
|
@ -128,6 +128,7 @@
|
||||
"Chinese Kongfu": "中国武術",
|
||||
"Allow external access to the API (service must be restarted)": "APIへの外部アクセスを許可する (サービスを再起動する必要があります)",
|
||||
"Custom": "カスタム",
|
||||
"CUDA (Beta, Faster)": "CUDA (ベータ、高速)",
|
||||
"Reset All Configs": "すべての設定をリセット",
|
||||
"Cancel": "キャンセル",
|
||||
"Confirm": "確認",
|
||||
|
@ -128,6 +128,7 @@
|
||||
"Chinese Kongfu": "情境冒险",
|
||||
"Allow external access to the API (service must be restarted)": "允许外部访问API (必须重启服务)",
|
||||
"Custom": "自定义",
|
||||
"CUDA (Beta, Faster)": "CUDA (Beta, 更快)",
|
||||
"Reset All Configs": "重置所有配置",
|
||||
"Cancel": "取消",
|
||||
"Confirm": "确认",
|
||||
|
@ -85,7 +85,9 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
|
||||
await exit(1000).catch(() => {
|
||||
});
|
||||
StartServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1').catch((e) => {
|
||||
StartServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1',
|
||||
modelConfig.modelParameters.device === 'CUDA-Beta'
|
||||
).catch((e) => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
toast(`${t('Error')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
|
||||
@ -118,7 +120,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
|
||||
const strategy = getStrategy(modelConfig);
|
||||
let customCudaFile = '';
|
||||
if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'Custom')
|
||||
if ((modelConfig.modelParameters.device.includes('CUDA') || modelConfig.modelParameters.device === 'Custom')
|
||||
&& modelConfig.modelParameters.useCustomCuda && !strategy.includes('fp32')) {
|
||||
if (commonStore.platform === 'windows') {
|
||||
customCudaFile = getSupportedCustomCudaFile();
|
||||
|
@ -30,7 +30,7 @@ export type ApiParameters = {
|
||||
frequencyPenalty: number;
|
||||
}
|
||||
|
||||
export type Device = 'CPU' | 'CUDA' | 'MPS' | 'Custom';
|
||||
export type Device = 'CPU' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom';
|
||||
export type Precision = 'fp16' | 'int8' | 'fp32';
|
||||
|
||||
export type ModelParameters = {
|
||||
@ -284,6 +284,8 @@ export const Configs: FC = observer(() => {
|
||||
<Option value="CPU">CPU</Option>
|
||||
{commonStore.platform === 'darwin' && <Option value="MPS">MPS</Option>}
|
||||
<Option value="CUDA">CUDA</Option>
|
||||
<Option value="CUDA-Beta">{t('CUDA (Beta, Faster)')!}</Option>
|
||||
<Option value="WebGPU" disabled>WebGPU</Option>
|
||||
<Option value="Custom">{t('Custom')!}</Option>
|
||||
</Dropdown>
|
||||
} />
|
||||
@ -308,12 +310,12 @@ export const Configs: FC = observer(() => {
|
||||
} />
|
||||
}
|
||||
{
|
||||
selectedConfig.modelParameters.device == 'CUDA' &&
|
||||
selectedConfig.modelParameters.device.includes('CUDA') &&
|
||||
<Labeled label={t('Current Strategy')}
|
||||
content={<Text> {getStrategy(selectedConfig)} </Text>} />
|
||||
}
|
||||
{
|
||||
selectedConfig.modelParameters.device == 'CUDA' &&
|
||||
selectedConfig.modelParameters.device.includes('CUDA') &&
|
||||
<Labeled label={t('Stored Layers')}
|
||||
desc={t('Number of the neural network layers loaded into VRAM, the more you load, the faster the speed, but it consumes more VRAM. (If your VRAM is not enough, it will fail to load)')}
|
||||
content={
|
||||
@ -327,7 +329,7 @@ export const Configs: FC = observer(() => {
|
||||
} />
|
||||
}
|
||||
{
|
||||
selectedConfig.modelParameters.device == 'CUDA' && <div />
|
||||
selectedConfig.modelParameters.device.includes('CUDA') && <div />
|
||||
}
|
||||
{
|
||||
displayStrategyImg &&
|
||||
|
@ -177,6 +177,7 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
|
||||
strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
|
||||
break;
|
||||
case 'CUDA':
|
||||
case 'CUDA-Beta':
|
||||
if (avoidOverflow)
|
||||
strategy = 'cuda fp32 *1 -> ';
|
||||
strategy += 'cuda ';
|
||||
|
2
frontend/wailsjs/go/backend_golang/App.d.ts
generated
vendored
2
frontend/wailsjs/go/backend_golang/App.d.ts
generated
vendored
@ -46,7 +46,7 @@ export function RestartApp():Promise<void>;
|
||||
|
||||
export function SaveJson(arg1:string,arg2:any):Promise<void>;
|
||||
|
||||
export function StartServer(arg1:string,arg2:number,arg3:string):Promise<string>;
|
||||
export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean):Promise<string>;
|
||||
|
||||
export function UpdateApp(arg1:string):Promise<boolean>;
|
||||
|
||||
|
4
frontend/wailsjs/go/backend_golang/App.js
generated
4
frontend/wailsjs/go/backend_golang/App.js
generated
@ -90,8 +90,8 @@ export function SaveJson(arg1, arg2) {
|
||||
return window['go']['backend_golang']['App']['SaveJson'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function StartServer(arg1, arg2, arg3) {
|
||||
return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3);
|
||||
export function StartServer(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function UpdateApp(arg1) {
|
||||
|
Loading…
Reference in New Issue
Block a user