rwkv5 pre-compiled kernel (for windows)

This commit is contained in:
josc146 2023-10-03 13:39:07 +08:00
parent 79851433f8
commit 68228a4552
3 changed files with 110 additions and 55 deletions

View File

@ -12,6 +12,56 @@ torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
current_path = os.path.dirname(os.path.abspath(__file__))
# https://zhuanlan.zhihu.com/p/612879065
def LoadPreCompileLibrary(file):
import importlib
import os
import torch
# load the custom_op_library and register the custom ops
lib_dir = os.path.dirname(__file__)
if os.name == "nt":
# Register the main torchvision library location on the default DLL path
import ctypes
import sys
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
prev_error_mode = kernel32.SetErrorMode(0x0001)
if with_load_library_flags:
kernel32.AddDllDirectory.restype = ctypes.c_void_p
if sys.version_info >= (3, 8):
os.add_dll_directory(lib_dir)
elif with_load_library_flags:
res = kernel32.AddDllDirectory(lib_dir)
if res is None:
err = ctypes.WinError(ctypes.get_last_error())
err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
raise ValueError(err)
kernel32.SetErrorMode(prev_error_mode)
loader_details = (
importlib.machinery.ExtensionFileLoader,
importlib.machinery.EXTENSION_SUFFIXES,
)
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
ext_specs = extfinder.find_spec(file)
if ext_specs is None:
return False
try:
torch.ops.load_library(ext_specs.origin)
except OSError as exc:
return False
return True
########################################################################################################
if os.environ.get("RWKV_JIT_ON") != "0":
@ -29,45 +79,48 @@ else:
MyStatic = __nop
if os.environ.get("RWKV_CUDA_ON") == "1":
from torch.utils.cpp_extension import load
DISABLE_CUBLAS_GEMM = False
from torch.utils.cpp_extension import load # L581
try:
load(
name=f"wkv_cuda",
sources=[
f"{current_path}/cuda/wrapper.cpp",
f"{current_path}/cuda/operators.cu",
f"{current_path}/cuda/gemm_fp16_cublas.cpp",
],
verbose=True,
extra_cuda_cflags=[
"--use_fast_math",
"-O3",
"--extra-device-vectorization",
],
is_python_module=False,
)
DISABLE_CUBLAS_GEMM = False
except:
print(
"Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow."
)
load(
name=f"wkv_cuda",
sources=[
f"{current_path}/cuda/wrapper.cpp",
f"{current_path}/cuda/operators.cu",
],
verbose=True,
extra_cuda_cflags=[
"--use_fast_math",
"-O3",
"--extra-device-vectorization",
],
extra_cflags=["-DDISABLE_CUBLAS_GEMM"],
is_python_module=False,
)
DISABLE_CUBLAS_GEMM = True
if LoadPreCompileLibrary("wkv_cuda") is False:
try:
load(
name=f"wkv_cuda",
sources=[
f"{current_path}/cuda/wrapper.cpp",
f"{current_path}/cuda/operators.cu",
f"{current_path}/cuda/gemm_fp16_cublas.cpp",
],
verbose=True,
extra_ldflags=["cublas.lib"],
extra_cuda_cflags=[
"--use_fast_math",
"-O3",
"--extra-device-vectorization",
],
is_python_module=False,
)
DISABLE_CUBLAS_GEMM = False
except:
print(
"Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow."
)
load(
name=f"wkv_cuda",
sources=[
f"{current_path}/cuda/wrapper.cpp",
f"{current_path}/cuda/operators.cu",
],
verbose=True,
extra_cuda_cflags=[
"--use_fast_math",
"-O3",
"--extra-device-vectorization",
],
extra_cflags=["-DDISABLE_CUBLAS_GEMM"],
is_python_module=False,
)
DISABLE_CUBLAS_GEMM = True
@MyStatic
def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp):
@ -527,24 +580,26 @@ class RWKV(MyModule):
if self.version == 5.2:
assert (
os.environ["RWKV_CUDA_ON"] == "1"
) # latest RWKV-5 requires os.environ["RWKV_CUDA_ON"] == '1' (will fix soon)
), "Please Enable Custom CUDA Kernel. Latest RWKV-5 requires os.environ['RWKV_CUDA_ON'] == '1' (will fix soon)"
HEAD_SIZE = args.n_att // args.n_head
rwkv5 = load(
name="rwkv5",
sources=[
f"{current_path}/cuda/rwkv5_op.cpp",
f"{current_path}/cuda/rwkv5.cu",
],
verbose=True,
extra_cuda_cflags=[
"-res-usage",
"--use_fast_math",
"-O3",
"-Xptxas -O3",
"--extra-device-vectorization",
f"-D_N_={HEAD_SIZE}",
],
)
if LoadPreCompileLibrary("rwkv5") is True:
rwkv5 = torch.ops.rwkv5
else:
rwkv5 = load(
name="rwkv5",
sources=[
f"{current_path}/cuda/rwkv5_op.cpp",
f"{current_path}/cuda/rwkv5.cu",
],
verbose=True,
extra_cuda_cflags=[
"-res-usage",
"--use_fast_math",
"-O3",
"--extra-device-vectorization",
f"-D_N_={HEAD_SIZE}",
],
)
class RWKV_5(torch.autograd.Function):
@staticmethod

BIN
backend-python/rwkv_pip/rwkv5.pyd vendored Normal file

Binary file not shown.

BIN
backend-python/rwkv_pip/wkv_cuda.pyd vendored Normal file

Binary file not shown.