add rwkv-cuda-beta support (faster)

This commit is contained in:
josc146 2023-08-14 22:07:15 +08:00
parent da68926e9c
commit 8a13bd3c1e
20 changed files with 2550 additions and 20 deletions

View File

@ -10,7 +10,7 @@ import (
"strings"
)
func (a *App) StartServer(python string, port int, host string) (string, error) {
func (a *App) StartServer(python string, port int, host string, rwkvBeta bool) (string, error) {
var err error
if python == "" {
python, err = GetPython()
@ -18,7 +18,12 @@ func (a *App) StartServer(python string, port int, host string) (string, error)
if err != nil {
return "", err
}
return Cmd(python, "./backend-python/main.py", strconv.Itoa(port), host)
args := []string{python, "./backend-python/main.py"}
if rwkvBeta {
args = append(args, "--rwkv-beta")
}
args = append(args, "--port", strconv.Itoa(port), "--host", host)
return Cmd(args...)
}
func (a *App) ConvertModel(python string, modelPath string, strategy string, outPath string) (string, error) {

View File

@ -1,5 +1,6 @@
from enum import Enum, auto
Args = "args"
Model = "model"
Model_Status = "model_status"
Model_Config = "model_config"

View File

@ -1,5 +1,11 @@
import time
start_time = time.time()
import os
import sys
import argparse
from typing import Sequence
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
@ -34,6 +40,11 @@ app.include_router(state_cache.router)
@app.on_event("startup")
def init():
global_var.init()
cmd_params = os.environ["RWKV_RUNNER_PARAMS"]
global_var.set(
global_var.Args, get_args(cmd_params.split(" ") if cmd_params else None)
)
state_cache.init()
set_torch()
@ -56,9 +67,34 @@ def exit():
parent.kill()
if __name__ == "__main__":
uvicorn.run(
"main:app",
port=8000 if len(sys.argv) < 2 else int(sys.argv[1]),
host="127.0.0.1" if len(sys.argv) < 3 else sys.argv[2],
def get_args(args: Union[Sequence[str], None] = None):
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="server arguments")
group.add_argument(
"--port",
type=int,
default=8000,
help="port to run the server on (default: 8000)",
)
group.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="host to run the server on (default: 127.0.0.1)",
)
group = parser.add_argument_group(title="mode arguments")
group.add_argument(
"--rwkv-beta",
action="store_true",
help="whether to use rwkv-beta (default: False)",
)
args = parser.parse_args(args)
return args
if __name__ == "__main__":
args = get_args()
os.environ["RWKV_RUNNER_PARAMS"] = " ".join(sys.argv[1:])
print("--- %s seconds ---" % (time.time() - start_time))
uvicorn.run("main:app", port=args.port, host=args.host, workers=1)

View File

@ -0,0 +1,124 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "element_wise.h"
#include "util.h"
// Equivalent Python code:
// ww = t_first + k
// p = torch.maximum(pp, ww)
// e1 = torch.exp(pp - p)
// e2 = torch.exp(ww - p)
// wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
// ww = t_decay + pp
// p = torch.maximum(ww, k)
// e1 = torch.exp(ww - p)
// e2 = torch.exp(k - p)
// t1 = e1 * aa + e2 * v
// t2 = e1 * bb + e2
// r = r * wkv
// return t1, t2, p, r
struct WkvForwardOne {
const float *t_first;
const float *k;
const float *pp;
const float *aa;
const float *bb;
const float *t_decay;
const float *v;
/* out */ float *t1;
/* out */ float *t2;
/* out */ float *p;
/* in & out */ half *r;
__device__ void operator()(int i) const {
float ww = t_first[i] + k[i];
float pp_ = pp[i];
float p_ = (pp_ > ww) ? pp_ : ww;
float e1 = expf(pp_ - p_);
float e2 = expf(ww - p_);
float aa_ = aa[i];
float bb_ = bb[i];
float v_ = v[i];
r[i] = __hmul(r[i], __float2half(((e1 * aa_ + e2 * v_) / (e1 * bb_ + e2))));
ww = t_decay[i] + pp_;
float k_ = k[i];
p_ = (ww > k_) ? ww : k_;
e1 = expf(ww - p_);
e2 = expf(k_ - p_);
t1[i] = e1 * aa_ + e2 * v_;
t2[i] = e1 * bb_ + e2;
p[i] = p_;
}
};
/*
Equivalent Python code:
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
*/
struct Mix {
const half *xx;
const half *sx;
const half *k_mix;
const half *v_mix;
const half *r_mix;
/* out */ half *kx;
/* out */ half *vx;
/* out */ half *rx;
__device__ void operator()(int i) const {
half xx_ = xx[i];
half sx_ = sx[i];
half k_mix_ = k_mix[i];
half v_mix_ = v_mix[i];
half r_mix_ = r_mix[i];
kx[i] = __hadd(__hmul(xx_, k_mix_),
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
vx[i] = __hadd(__hmul(xx_, v_mix_),
__hmul(sx_, __hsub(__float2half(1), v_mix_)));
rx[i] = __hadd(__hmul(xx_, r_mix_),
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
}
};
using torch::Tensor;
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
Tensor v_mix, Tensor r_mix, Tensor kw,
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, Tensor rw,
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
/* imm */ Tensor k, Tensor pp, Tensor ww, Tensor aa, Tensor bb,
Tensor t_decay, /* imm */ Tensor v, /* in & out */ Tensor r,
/* out */ Tensor x_plus_out, /* out */ Tensor t1,
/* out */ Tensor t2, /* out */ Tensor p) {
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
data_ptr<half>(k_mix), data_ptr<half>(v_mix),
data_ptr<half>(r_mix), data_ptr<half>(kx),
data_ptr<half>(vx), data_ptr<half>(rx)},
x.numel());
gemm_fp16_cublas(kx, kw, k);
gemm_fp16_cublas(vx, vw, v);
gemm_fp16_cublas(rx, rw, r);
at::sigmoid_(r);
element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
data_ptr<float>(pp), data_ptr<float>(aa),
data_ptr<float>(bb), data_ptr<float>(t_decay),
data_ptr<float>(v), data_ptr<float>(t1),
data_ptr<float>(t2), data_ptr<float>(p),
data_ptr<half>(r)},
x.numel());
gemm_fp16_cublas(r, ow, x_plus_out);
x_plus_out += x;
return xx;
}

View File

@ -0,0 +1,179 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "util.h"
#include "element_wise.h"
using torch::Tensor;
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
void gemm_fp16_cublas(const void *a, const void *b, void *c, int m,
int n, int k, bool output_fp32);
// based on `kernel_wkv_forward`, fusing more operations
__global__ void kernel_wkv_forward_new(
const int B, const int T, const int C, const float *__restrict__ const _w,
const float *__restrict__ const _u, const float *__restrict__ const _k,
const float *__restrict__ const _v, const half *__restrict__ const r,
half *__restrict__ const _y, float *__restrict__ const _aa,
float *__restrict__ const _bb, float *__restrict__ const _pp) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
const int _state_offset = _b * C + _c;
float u = _u[_c];
float w = _w[_c];
const float *__restrict__ const k = _k + _offset;
const float *__restrict__ const v = _v + _offset;
half *__restrict__ const y = _y + _offset;
float aa = _aa[_state_offset];
float bb = _bb[_state_offset];
float pp = _pp[_state_offset];
for (int i = 0; i < T; i++) {
const int ii = i * C;
const float kk = k[ii];
const float vv = v[ii];
float ww = u + kk;
float p = max(pp, ww);
float e1 = exp(pp - p);
float e2 = exp(ww - p);
y[ii] = __float2half((e1 * aa + e2 * vv) / (e1 * bb + e2));
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
_aa[_state_offset] = aa;
_bb[_state_offset] = bb;
_pp[_state_offset] = pp;
}
void cuda_wkv_forward_new(int B, int T, int C, float *w, float *u, float *k,
float *v, half *r, half *y, float *aa, float *bb,
float *pp) {
dim3 threadsPerBlock(min(C, 32));
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_wkv_forward_new<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, r,
y, aa, bb, pp);
}
__global__ void _att_mix(const half *xx, const half *sx, const half *k_mix,
const half *v_mix, const half *r_mix,
const int outer_size, const int inner_size, half *kx,
half *vx, half *rx) {
for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size;
idx2 += blockDim.x * gridDim.x) {
half k_mix_ = k_mix[idx2];
half v_mix_ = v_mix[idx2];
half r_mix_ = r_mix[idx2];
for (int row = 0; row < outer_size; ++row) {
int idx1 = row * inner_size + idx2;
half xx_ = xx[idx1];
half sx_ = sx[idx1];
kx[idx1] = __hadd(__hmul(xx_, k_mix_),
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
vx[idx1] = __hadd(__hmul(xx_, v_mix_),
__hmul(sx_, __hsub(__float2half(1), v_mix_)));
rx[idx1] = __hadd(__hmul(xx_, r_mix_),
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
}
}
}
void att_mix(const half *xx, const half *sx, const half *k_mix,
const half *v_mix, const half *r_mix, const int outer_size,
const int inner_size, half *kx, half *vx, half *rx) {
// 256 is good enough on most GPUs
const int32_t BLOCK_SIZE = 256;
assert(inner_size % BLOCK_SIZE == 0);
_att_mix<<<inner_size / BLOCK_SIZE, BLOCK_SIZE>>>(
xx, sx, k_mix, v_mix, r_mix, outer_size, inner_size, kx, vx, rx);
}
struct InplaceSigmoid {
__device__ __forceinline__ half operator()(int i) const {
ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i]))));
}
half *ptr;
};
struct InplaceMul {
__device__ __forceinline__ half operator()(int i) const {
y[i] = __hmul(x[i], y[i]);
}
half *y;
half *x;
};
/*
Equivalent Python code:
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw))
k = gemm(kx, kw, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32)
T = x.shape[0]
for t in range(T):
kk = k[t]
vv = v[t]
ww = t_first + kk
p = torch.maximum(pp, ww)
e1 = torch.exp(pp - p)
e2 = torch.exp(ww - p)
sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
ww = t_decay + pp
p = torch.maximum(ww, kk)
e1 = torch.exp(ww - p)
e2 = torch.exp(kk - p)
aa = e1 * aa + e2 * vv
bb = e1 * bb + e2
pp = p
out = gemm(r * sx, ow)
return x + out, xx[-1,:], aa, bb, pp
*/
Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
Tensor v_mix, Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out) {
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0);
char* buf_ptr = (char*)buf.data_ptr();
half* kx = (half*)buf_ptr;
half* vx = kx + x.numel();
half* rx = vx + x.numel();
half* wkv_y = rx + x.numel();
att_mix(data_ptr<half>(xx), data_ptr<half>(sx), data_ptr<half>(k_mix),
data_ptr<half>(v_mix), data_ptr<half>(r_mix), xx.size(0), xx.size(1),
kx, vx, rx);
float* k = reinterpret_cast<float*>(wkv_y + x.numel());
float* v = k + x.size(0) * kw.size(1);
half* r = reinterpret_cast<half*>(v + x.size(0) * vw.size(1));
gemm_fp16_cublas(kx, kw.data_ptr(), k, x.size(0), kw.size(1), kw.size(0), true);
gemm_fp16_cublas(vx, vw.data_ptr(), v, x.size(0), vw.size(1), vw.size(0), true);
gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), rw.size(0), false);
element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1));
cuda_wkv_forward_new(1, x.size(0), x.size(1), data_ptr<float>(t_decay),
data_ptr<float>(t_first), k, v, r,
wkv_y, data_ptr<float>(aa),
data_ptr<float>(bb), data_ptr<float>(pp));
element_wise(InplaceMul{wkv_y, r}, x.numel());
gemm_fp16_cublas(wkv_y, ow.data_ptr(), x_plus_out.data_ptr(), x.size(0), ow.size(1), ow.size(0), false);
x_plus_out += x;
return xx;
}

View File

@ -0,0 +1,21 @@
#include <cassert>
#include <cstddef>
#include <cstdint>
template <typename Func> __global__ void _element_wise(Func func, int n) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
i += blockDim.x * gridDim.x) {
func(i);
}
}
// NOTE: packed data type (e.g. float4) is a overkill for current sizes
// (4096 in 7B model and 768 in 0.1B model),
// and is not faster than the plain float version.
template <typename Func>
void element_wise(Func func, int n) {
// 256 is good enough on most GPUs
const int32_t BLOCK_SIZE = 256;
assert(n % BLOCK_SIZE == 0);
_element_wise<<<n / BLOCK_SIZE, BLOCK_SIZE>>>(func, n);
}

165
backend-python/rwkv_pip/beta/cuda/ffn.cu vendored Normal file
View File

@ -0,0 +1,165 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "element_wise.h"
#include "util.h"
using torch::Tensor;
void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m,
int ori_n, int ori_k, bool output_fp32);
__global__ void _ffn_seq_mix(const half *xx, const half *sx, const half *k_mix,
const half *r_mix, const int outer_size,
const int inner_size, half *kx, half *rx) {
for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size;
idx2 += blockDim.x * gridDim.x) {
half k_mix_ = k_mix[idx2];
half r_mix_ = r_mix[idx2];
for (int row = 0; row < outer_size; ++row) {
int idx1 = row * inner_size + idx2;
half xx_ = xx[idx1];
half sx_ = sx[idx1];
kx[idx1] = __hadd(__hmul(xx_, k_mix_),
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
rx[idx1] = __hadd(__hmul(xx_, r_mix_),
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
}
}
}
void ffn_seq_mix(const half *xx, const half *sx, const half *k_mix,
const half *r_mix, const int outer_size, const int inner_size,
half *kx, half *rx) {
// 256 is good enough on most GPUs
const int32_t BLOCK_SIZE = 256;
assert(inner_size % BLOCK_SIZE == 0);
_ffn_seq_mix<<<inner_size / BLOCK_SIZE, BLOCK_SIZE>>>(
xx, sx, k_mix, r_mix, outer_size, inner_size, kx, rx);
}
struct InplaceSigmoid {
__device__ __forceinline__ void operator()(int i) const {
ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i]))));
}
half *ptr;
};
struct InplaceReLUAndSquare {
__device__ __forceinline__ void operator()(int i) const {
// __hmax is not defined in old cuda
if (__hgt(ptr[i], __float2half(0))) {
ptr[i] = __hmul(ptr[i], ptr[i]);
} else {
ptr[i] = __float2half(0);
}
}
half *ptr;
};
struct InplaceFma {
__device__ __forceinline__ void operator()(int i) const {
a[i] = __hfma(a[i], b[i], c[i]);
}
half *a;
const half *b;
const half *c;
};
/*
Equivalent Python code:
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
kx = xx * k_mix + sx * (1 - k_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw))
vx = torch.square(torch.relu(gemm(kx, kw)))
out = r * gemm(vx, vw)
return x + out, xx[-1,:]
*/
Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
/* imm */ Tensor buf,
/* out */ Tensor x_plus_out) {
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0);
char *buf_ptr = (char *)buf.data_ptr();
half *kx = (half *)buf_ptr;
half *rx = kx + x.numel();
half *vx = rx + x.numel();
half *r = vx + x.size(0) * kw.size(1);
ffn_seq_mix(data_ptr<half>(xx), data_ptr<half>(sx), data_ptr<half>(k_mix),
data_ptr<half>(r_mix), xx.size(0), xx.size(1), kx, rx);
gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), x.size(1),
false);
element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1));
gemm_fp16_cublas(kx, kw.data_ptr(), vx, x.size(0), kw.size(1), x.size(1),
false);
element_wise(InplaceReLUAndSquare{vx}, x.size(0) * kw.size(1));
gemm_fp16_cublas(vx, vw.data_ptr(), x_plus_out.data_ptr(), x.size(0),
vw.size(1), vw.size(0), false);
element_wise(InplaceFma{data_ptr<half>(x_plus_out), r, data_ptr<half>(x)},
x_plus_out.numel());
return xx;
}
struct FfnOneMix {
__device__ __forceinline__ void operator()(int idx) {
half k_mix_ = k_mix[idx];
half r_mix_ = r_mix[idx];
half xx_ = xx[idx];
half sx_ = sx[idx];
kx[idx] = __hadd(__hmul(xx_, k_mix_),
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
rx[idx] = __hadd(__hmul(xx_, r_mix_),
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
}
half *k_mix;
half *r_mix;
half *xx;
half *sx;
half *kx;
half *rx;
};
/*
Equivalent Python code:
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
kx = xx * k_mix + sx * (1 - k_mix)
rx = xx * r_mix + sx * (1 - r_mix)
r = torch.sigmoid(gemm(rx, rw))
vx = torch.square(torch.relu(gemm(kx, kw)))
out = r * gemm(vx, vw)
return x + out, xx
*/
Tensor ffn_one(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
/* imm */ Tensor buf,
/* out */ Tensor x_plus_out) {
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
char *buf_ptr = (char *)buf.data_ptr();
half *kx = (half *)buf_ptr;
half *rx = kx + x.numel();
half *vx = rx + x.numel();
half *r = vx + x.size(0) * kw.size(1);
element_wise(FfnOneMix{data_ptr<half>(k_mix), data_ptr<half>(r_mix),
data_ptr<half>(xx), data_ptr<half>(sx), kx, rx},
x.numel());
// vector * matrix, so m = 1
gemm_fp16_cublas(rx, rw.data_ptr(), r, 1, rw.size(1), rw.size(0), false);
element_wise(InplaceSigmoid{r}, rw.size(1));
gemm_fp16_cublas(kx, kw.data_ptr(), vx, 1, kw.size(1), kw.size(0), false);
element_wise(InplaceReLUAndSquare{vx}, kw.size(1));
gemm_fp16_cublas(vx, vw.data_ptr(), x_plus_out.data_ptr(), 1, vw.size(1),
vw.size(0), false);
element_wise(InplaceFma{data_ptr<half>(x_plus_out), r, data_ptr<half>(x)},
x_plus_out.numel());
return xx;
}

View File

@ -0,0 +1,80 @@
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#define CUBLAS_CHECK(condition) \
for (cublasStatus_t _cublas_check_status = (condition); \
_cublas_check_status != CUBLAS_STATUS_SUCCESS;) \
throw std::runtime_error("cuBLAS error " + \
std::to_string(_cublas_check_status) + " at " + \
std::to_string(__LINE__));
#define CUDA_CHECK(condition) \
for (cudaError_t _cuda_check_status = (condition); \
_cuda_check_status != cudaSuccess;) \
throw std::runtime_error( \
"CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
" at " + std::to_string(__LINE__));
cublasHandle_t get_cublas_handle() {
static cublasHandle_t cublas_handle = []() {
cublasHandle_t handle = nullptr;
CUBLAS_CHECK(cublasCreate(&handle));
#if CUDA_VERSION < 11000
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#else
CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
return handle;
}();
return cublas_handle;
}
/*
NOTE: blas gemm is column-major by default, but we need row-major output.
The data of row-major, transposed matrix is exactly the same as the
column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
*/
void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m,
int ori_n, int ori_k, bool output_fp32) {
const auto cuda_data_type = CUDA_R_16F;
const auto cuda_c_data_type = output_fp32 ? CUDA_R_32F : CUDA_R_16F;
const auto compute_type = CUDA_R_32F;
const float sp_alpha = 1.f;
// use CUBLAS_OP_N. see the notes above
const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
// m = (B^T).size(0) = B.size(1) = n;
const int cublas_m = ori_n;
const int cublas_k = ori_k;
// comptiable with rwkv one mode, where 1-D tensor * 2-D tensor
// const int n = a.dense_dim() == 1 ? 1 : a.size(0);
const int cublas_n = ori_m;
const int cublas_lda = cublas_m;
const int cublas_ldb = cublas_k;
const int cublas_ldc = cublas_m;
cublasHandle_t cublas_handle = get_cublas_handle();
#if CUDA_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
#else
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#endif
const float sp_beta = 0.f;
CUBLAS_CHECK(cublasGemmEx(
cublas_handle, cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
cublas_k, &sp_alpha, b, cuda_data_type, cublas_lda,
a, cuda_data_type, cublas_ldb, &sp_beta, c,
cuda_c_data_type, cublas_ldc, compute_type, algo));
}
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
// comptiable with rwkv one mode, 1-D tensor * 2-D tensor
const int m = a.dense_dim() == 1 ? 1 : a.size(0);
const int n = b.size(1);
const int k = b.size(0);
gemm_fp16_cublas(a.data_ptr(), b.data_ptr(), c.data_ptr(), m, n, k,
c.dtype() == torch::kFloat32);
}

View File

@ -0,0 +1,246 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#define MIN_VALUE (-1e38)
typedef at::Half fp16;
__half *cast(fp16 *ptr) {
return reinterpret_cast<__half *>(ptr);
}
template <typename F>
__global__ void kernel_wkv_forward(const int B, const int T, const int C,
const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;
const int _state_offset = _b * C + _c;
float u = _u[_c];
float w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;
float aa = _aa[_state_offset];
float bb = _bb[_state_offset];
float pp = _pp[_state_offset];
for (int i = 0; i < T; i++) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
float ww = u + kk;
float p = max(pp, ww);
float e1 = exp(pp - p);
float e2 = exp(ww - p);
y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
_aa[_state_offset] = aa;
_bb[_state_offset] = bb;
_pp[_state_offset] = pp;
}
template <typename F>
void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
dim3 threadsPerBlock( min(C, 32) );
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
}
template void cuda_wkv_forward<fp16>(
int B, int T, int C,
float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
float *aa, float *bb, float *pp);
template void cuda_wkv_forward<float>(
int B, int T, int C,
float *w, float *u, float *k, float *v, float *y,
float *aa, float *bb, float *pp);
__global__ void kernel_mm_seq_fp32i8(
const int B, const int N, const int M,
const float *__restrict__ const x, const int x_stride,
const uint8_t *__restrict__ const w, const int w_stride,
const float *__restrict__ const mx,
const float *__restrict__ const rx,
const float *__restrict__ const my,
const float *__restrict__ const ry,
float *__restrict__ const y, const int y_stride) {
const int i = blockIdx.x * blockDim.x + threadIdx.x;
const int k = blockIdx.y * blockDim.y + threadIdx.y;
if (i < B && k < M) {
float y_local = 0;
for (int j = 0; j < N; ++j) {
y_local += x[i * x_stride + j] * (
(float(w[j * w_stride + k]) + 0.5f)
* rx[k] * ry[j] + mx[k] + my[j]
);
}
y[i * y_stride + k] = y_local;
}
}
template <typename F>
void cuda_mm8_seq(int B, int N, int M,
F *x, int x_stride,
uint8_t *w, int w_stride,
F *mx, F *rx,
F *my, F *ry,
F *y, int y_stride);
template <>
void cuda_mm8_seq<float>(int B, int N, int M,
float *x, int x_stride,
uint8_t *w, int w_stride,
float *mx, float *rx,
float *my, float *ry,
float *y, int y_stride) {
dim3 blockSize(1, 128);
dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
B, N, M, x, x_stride, w, w_stride,
mx, rx, my, ry, y, y_stride);
}
__global__ void kernel_mm_seq_fp16i8(
const int B, const int N, const int M,
const __half *__restrict__ const x, const int x_stride,
const uint8_t *__restrict__ const w, const int w_stride,
const __half *__restrict__ const mx,
const __half *__restrict__ const rx,
const __half *__restrict__ const my,
const __half *__restrict__ const ry,
__half *__restrict__ const y, const int y_stride) {
const int i = blockIdx.x * blockDim.x + threadIdx.x;
const int k = blockIdx.y * blockDim.y + threadIdx.y;
if (i < B && k < M) {
float y_local = 0;
for (int j = 0; j < N; ++j) {
y_local += __half2float(x[i * x_stride + j]) * (
(float(w[j * w_stride + k]) + 0.5f)
* __half2float(rx[k]) * __half2float(ry[j])
+ __half2float(mx[k]) + __half2float(my[j])
);
}
y[i * y_stride + k] = __float2half(y_local);
}
}
template <>
void cuda_mm8_seq<fp16>(int B, int N, int M,
fp16 *x, int x_stride,
uint8_t *w, int w_stride,
fp16 *mx, fp16 *rx,
fp16 *my, fp16 *ry,
fp16 *y, int y_stride) {
dim3 blockSize(1, 128);
dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
B, N, M, cast(x), x_stride, w, w_stride,
cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
}
#define MM8_ONE_JSPLIT 24
#define MM8_ONE_TILE 1024
__global__ void kernel_mm_one_fp32i8(
const int N, const int M,
const float *__restrict__ const x,
const uint8_t *__restrict__ const w, const int w_stride,
const float *__restrict__ const mx,
const float *__restrict__ const rx,
const float *__restrict__ const my,
const float *__restrict__ const ry,
float *__restrict__ const y) {
const int k = blockIdx.y * blockDim.y + threadIdx.y;
const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
if (k < M) {
float y_local = 0;
for (int j = j0; j < j1; ++j) {
y_local += x[j] * (
(float(w[j * w_stride + k]) + 0.5f)
* rx[k] * ry[j] + mx[k] + my[j]
);
}
atomicAdd(&y[k], y_local);
}
}
template <typename F>
void cuda_mm8_one(int N, int M,
F *x,
uint8_t *w, int w_stride,
F *mx, F *rx,
F *my, F *ry,
float *y);
template <>
void cuda_mm8_one<float>(int N, int M,
float *x,
uint8_t *w, int w_stride,
float *mx, float *rx,
float *my, float *ry,
float *y) {
dim3 blockSize(1, MM8_ONE_TILE);
dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
N, M, x, w, w_stride,
mx, rx, my, ry, y);
}
__global__ void kernel_mm_one_fp16i8(
const int N, const int M,
const __half *__restrict__ const x,
const uint8_t *__restrict__ const w, const int w_stride,
const __half *__restrict__ const mx,
const __half *__restrict__ const rx,
const __half *__restrict__ const my,
const __half *__restrict__ const ry,
float *__restrict__ const y) {
const int k = blockIdx.y * blockDim.y + threadIdx.y;
const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
if (k < M) {
float y_local = 0;
for (int j = j0; j < j1; ++j) {
y_local += __half2float(x[j]) * (
(float(w[j * w_stride + k]) + 0.5f)
* __half2float(rx[k]) * __half2float(ry[j])
+ __half2float(mx[k]) + __half2float(my[j])
);
}
atomicAdd(&y[k], y_local);
}
}
template <>
void cuda_mm8_one<fp16>(int N, int M,
fp16 *x,
uint8_t *w, int w_stride,
fp16 *mx, fp16 *rx,
fp16 *my, fp16 *ry,
float *y) {
dim3 blockSize(1, MM8_ONE_TILE);
dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
N, M, cast(x), w, w_stride,
cast(mx), cast(rx), cast(my), cast(ry), y);
}

View File

@ -0,0 +1,7 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
template <typename T> T *data_ptr(torch::Tensor x) { return x.data_ptr<T>(); }
template <> inline half *data_ptr(torch::Tensor x) {
return reinterpret_cast<half *>(x.data_ptr<at::Half>());
}

View File

@ -0,0 +1,167 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <iostream>
#include <c10/cuda/CUDAGuard.h>
typedef at::Half fp16;
template <typename F>
void cuda_wkv_forward(int B, int T, int C,
float *w, float *u, F *k, F *v, F *y,
float *aa, float *bb, float *pp);
template <typename F>
void cuda_mm8_seq(int B, int N, int M,
F *x, int x_stride,
uint8_t *w, int w_stride,
F *mx, F *rx,
F *my, F *ry,
F *y, int y_stride);
template <typename F>
void cuda_mm8_one(int N, int M,
F *x,
uint8_t *w, int w_stride,
F *mx, F *rx,
F *my, F *ry,
float *y);
void wkv_forward(int64_t B, int64_t T, int64_t C,
torch::Tensor &w, torch::Tensor &u,
torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
switch (k.scalar_type()) {
case c10::ScalarType::Half:
cuda_wkv_forward(B, T, C,
w.data_ptr<float>(), u.data_ptr<float>(),
k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
break;
case c10::ScalarType::Float:
cuda_wkv_forward(B, T, C,
w.data_ptr<float>(), u.data_ptr<float>(),
k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
break;
default:
assert(false && "Only FP16 and FP32 are currently supported");
}
}
void mm8_seq(int64_t B, int64_t N, int64_t M,
torch::Tensor &x, torch::Tensor &w,
torch::Tensor &mx, torch::Tensor &rx,
torch::Tensor &my, torch::Tensor &ry,
torch::Tensor &y) {
assert(x.stride(1) == 1);
assert(w.stride(1) == 1);
assert(mx.stride(0) == 1 && rx.stride(0) == 1);
assert(my.stride(0) == 1 && ry.stride(0) == 1);
assert(y.stride(1) == 1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
switch (x.scalar_type()) {
case c10::ScalarType::Half:
cuda_mm8_seq(
B, N, M,
x.data_ptr<fp16>(), x.stride(0),
w.data_ptr<uint8_t>(), w.stride(0),
mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
y.data_ptr<fp16>(), y.stride(0));
break;
case c10::ScalarType::Float:
cuda_mm8_seq(
B, N, M,
x.data_ptr<float>(), x.stride(0),
w.data_ptr<uint8_t>(), w.stride(0),
mx.data_ptr<float>(), rx.data_ptr<float>(),
my.data_ptr<float>(), ry.data_ptr<float>(),
y.data_ptr<float>(), y.stride(0));
break;
default:
assert(false && "Only FP16 and FP32 are currently supported");
}
}
void mm8_one(int64_t N, int64_t M,
torch::Tensor &x, torch::Tensor &w,
torch::Tensor &mx, torch::Tensor &rx,
torch::Tensor &my, torch::Tensor &ry,
torch::Tensor &y) {
assert(x.stride(0) == 1);
assert(w.stride(1) == 1);
assert(mx.stride(0) == 1 && rx.stride(0) == 1);
assert(my.stride(0) == 1 && ry.stride(0) == 1);
assert(y.stride(0) == 1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
switch (x.scalar_type()) {
case c10::ScalarType::Half:
cuda_mm8_one(
N, M,
x.data_ptr<fp16>(),
w.data_ptr<uint8_t>(), w.stride(0),
mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
y.data_ptr<float>());
break;
case c10::ScalarType::Float:
cuda_mm8_one(
N, M,
x.data_ptr<float>(),
w.data_ptr<uint8_t>(), w.stride(0),
mx.data_ptr<float>(), rx.data_ptr<float>(),
my.data_ptr<float>(), ry.data_ptr<float>(),
y.data_ptr<float>());
break;
default:
assert(false && "Only FP16 and FP32 are currently supported");
}
}
using torch::Tensor;
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
Tensor v_mix, Tensor r_mix, Tensor kw,
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, Tensor rw,
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
/* imm */ Tensor k, Tensor pp, Tensor ww, Tensor aa, Tensor bb,
Tensor t_decay, /* imm */ Tensor v, /* in & out */ Tensor r,
/* out */ Tensor x_plus_out, /* out */ Tensor t1,
/* out */ Tensor t2, /* out */ Tensor p);
Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
Tensor v_mix, Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out);
Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
/* imm */ Tensor buf,
/* out */ Tensor x_plus_out);
Tensor ffn_one(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
/* imm */ Tensor buf,
/* out */ Tensor x_plus_out);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("wkv_forward", &wkv_forward, "wkv forward");
m.def("mm8_seq", &mm8_seq, "mm8 seq");
m.def("mm8_one", &mm8_one, "mm8 one");
m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
m.def("att_one", &att_one, "att one");
m.def("att_seq", &att_seq, "att seq");
m.def("ffn_seq", &ffn_seq, "ffn seq");
m.def("ffn_one", &ffn_one, "ffn one");
}
TORCH_LIBRARY(rwkv, m) {
m.def("wkv_forward", wkv_forward);
m.def("mm8_seq", mm8_seq);
m.def("mm8_one", mm8_one);
m.def("gemm_fp16_cublas", gemm_fp16_cublas);
m.def("att_one", att_one);
m.def("att_seq", att_seq);
m.def("ffn_seq", ffn_seq);
m.def("ffn_one", ffn_one);
}

1479
backend-python/rwkv_pip/beta/model.py vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -10,6 +10,7 @@ from fastapi import HTTPException
from pydantic import BaseModel, Field
import numpy as np
from routes import state_cache
import global_var
END_OF_TEXT = 0
@ -27,7 +28,17 @@ class RWKVType(Enum):
class AbstractRWKV(ABC):
def __init__(self, model: str, strategy: str, tokens_path: str):
from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
# dynamic import to make RWKV_CUDA_ON work
if rwkv_beta:
from rwkv_pip.beta.model import (
RWKV as Model,
)
else:
from rwkv.model import (
RWKV as Model,
)
from rwkv_pip.utils import PIPELINE
filename, _ = os.path.splitext(os.path.basename(model))
@ -221,7 +232,7 @@ class AbstractRWKV(ABC):
return state[0].tolist(), token_len
def generate(
self, prompt: str, stop: Union[str, List[str]] = None
self, prompt: str, stop: Union[str, List[str], None] = None
) -> Iterable[Tuple[str, str, int, int]]:
quick_log(None, None, "Generation Prompt:\n" + prompt)
cache = None
@ -438,8 +449,10 @@ The following is a coherent verbose detailed conversation between a girl named {
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
if self.rwkv_type == RWKVType.Raven
else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
else (
f"{user}{interface} hi\n\n{bot}{interface} Hi. "
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
)
)
logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
try:

View File

@ -128,6 +128,7 @@
"Chinese Kongfu": "中国武術",
"Allow external access to the API (service must be restarted)": "APIへの外部アクセスを許可する (サービスを再起動する必要があります)",
"Custom": "カスタム",
"CUDA (Beta, Faster)": "CUDA (ベータ、高速)",
"Reset All Configs": "すべての設定をリセット",
"Cancel": "キャンセル",
"Confirm": "確認",

View File

@ -128,6 +128,7 @@
"Chinese Kongfu": "情境冒险",
"Allow external access to the API (service must be restarted)": "允许外部访问API (必须重启服务)",
"Custom": "自定义",
"CUDA (Beta, Faster)": "CUDA (Beta, 更快)",
"Reset All Configs": "重置所有配置",
"Cancel": "取消",
"Confirm": "确认",

View File

@ -85,7 +85,9 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
await exit(1000).catch(() => {
});
StartServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1').catch((e) => {
StartServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1',
modelConfig.modelParameters.device === 'CUDA-Beta'
).catch((e) => {
const errMsg = e.message || e;
if (errMsg.includes('path contains space'))
toast(`${t('Error')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
@ -118,7 +120,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
const strategy = getStrategy(modelConfig);
let customCudaFile = '';
if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'Custom')
if ((modelConfig.modelParameters.device.includes('CUDA') || modelConfig.modelParameters.device === 'Custom')
&& modelConfig.modelParameters.useCustomCuda && !strategy.includes('fp32')) {
if (commonStore.platform === 'windows') {
customCudaFile = getSupportedCustomCudaFile();

View File

@ -30,7 +30,7 @@ export type ApiParameters = {
frequencyPenalty: number;
}
export type Device = 'CPU' | 'CUDA' | 'MPS' | 'Custom';
export type Device = 'CPU' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom';
export type Precision = 'fp16' | 'int8' | 'fp32';
export type ModelParameters = {
@ -284,6 +284,8 @@ export const Configs: FC = observer(() => {
<Option value="CPU">CPU</Option>
{commonStore.platform === 'darwin' && <Option value="MPS">MPS</Option>}
<Option value="CUDA">CUDA</Option>
<Option value="CUDA-Beta">{t('CUDA (Beta, Faster)')!}</Option>
<Option value="WebGPU" disabled>WebGPU</Option>
<Option value="Custom">{t('Custom')!}</Option>
</Dropdown>
} />
@ -308,12 +310,12 @@ export const Configs: FC = observer(() => {
} />
}
{
selectedConfig.modelParameters.device == 'CUDA' &&
selectedConfig.modelParameters.device.includes('CUDA') &&
<Labeled label={t('Current Strategy')}
content={<Text> {getStrategy(selectedConfig)} </Text>} />
}
{
selectedConfig.modelParameters.device == 'CUDA' &&
selectedConfig.modelParameters.device.includes('CUDA') &&
<Labeled label={t('Stored Layers')}
desc={t('Number of the neural network layers loaded into VRAM, the more you load, the faster the speed, but it consumes more VRAM. (If your VRAM is not enough, it will fail to load)')}
content={
@ -327,7 +329,7 @@ export const Configs: FC = observer(() => {
} />
}
{
selectedConfig.modelParameters.device == 'CUDA' && <div />
selectedConfig.modelParameters.device.includes('CUDA') && <div />
}
{
displayStrategyImg &&

View File

@ -177,6 +177,7 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
break;
case 'CUDA':
case 'CUDA-Beta':
if (avoidOverflow)
strategy = 'cuda fp32 *1 -> ';
strategy += 'cuda ';

View File

@ -46,7 +46,7 @@ export function RestartApp():Promise<void>;
export function SaveJson(arg1:string,arg2:any):Promise<void>;
export function StartServer(arg1:string,arg2:number,arg3:string):Promise<string>;
export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean):Promise<string>;
export function UpdateApp(arg1:string):Promise<boolean>;

View File

@ -90,8 +90,8 @@ export function SaveJson(arg1, arg2) {
return window['go']['backend_golang']['App']['SaveJson'](arg1, arg2);
}
export function StartServer(arg1, arg2, arg3) {
return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3);
export function StartServer(arg1, arg2, arg3, arg4) {
return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4);
}
export function UpdateApp(arg1) {