upgrade cuda-beta

This commit is contained in:
josc146
2023-09-15 16:30:11 +08:00
parent c4042bbfd8
commit df969fcfc6
6 changed files with 601 additions and 90 deletions

View File

@@ -88,7 +88,7 @@ struct Mix {
using torch::Tensor;
-void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
+void gemm_fp16_cublas_tensor(Tensor a, Tensor b, Tensor c);
Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
Tensor v_mix, Tensor r_mix, Tensor kw,
@@ -105,9 +105,9 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
data_ptr<half>(vx), data_ptr<half>(rx)},
x.numel());
-gemm_fp16_cublas(kx, kw, k);
-gemm_fp16_cublas(vx, vw, v);
-gemm_fp16_cublas(rx, rw, r);
+gemm_fp16_cublas_tensor(kx, kw, k);
+gemm_fp16_cublas_tensor(vx, vw, v);
+gemm_fp16_cublas_tensor(rx, rw, r);
at::sigmoid_(r);
element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
@@ -118,7 +118,7 @@ Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
data_ptr<half>(r)},
x.numel());
-gemm_fp16_cublas(r, ow, x_plus_out);
+gemm_fp16_cublas_tensor(r, ow, x_plus_out);
x_plus_out += x;
return xx;
}