rwkv5 pre-compiled kernel (for windows)

This commit is contained in:
josc146 2023-10-03 13:39:07 +08:00
parent 79851433f8
commit 68228a4552
3 changed files with 110 additions and 55 deletions

View File

@ -12,6 +12,56 @@ torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
current_path = os.path.dirname(os.path.abspath(__file__)) current_path = os.path.dirname(os.path.abspath(__file__))
# https://zhuanlan.zhihu.com/p/612879065
def LoadPreCompileLibrary(file):
    """Try to load a pre-compiled torch extension module named *file*.

    Searches the directory containing this module for a compiled extension
    (e.g. ``wkv_cuda.pyd`` / ``.so``) and registers its custom ops via
    ``torch.ops.load_library``.

    Returns:
        True if the extension was found and loaded, False if it is missing
        or failed to load (the caller then falls back to JIT compilation).
    """
    # importlib.machinery is a submodule: import it explicitly instead of
    # relying on `import importlib` having pulled it in as a side effect.
    import importlib.machinery
    import os

    import torch

    lib_dir = os.path.dirname(__file__)
    if os.name == "nt":
        # On Windows, register this directory on the DLL search path so that
        # DLLs the .pyd depends on can be resolved next to it.
        # (Adapted from torchvision's extension loader,
        # see https://zhuanlan.zhihu.com/p/612879065)
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        prev_error_mode = kernel32.SetErrorMode(0x0001)

        if with_load_library_flags:
            kernel32.AddDllDirectory.restype = ctypes.c_void_p

        if sys.version_info >= (3, 8):
            os.add_dll_directory(lib_dir)
        elif with_load_library_flags:
            res = kernel32.AddDllDirectory(lib_dir)
            if res is None:
                err = ctypes.WinError(ctypes.get_last_error())
                err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
                raise ValueError(err)

        kernel32.SetErrorMode(prev_error_mode)

    # Look for a compiled extension (any platform-specific suffix) in lib_dir.
    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES,
    )
    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec(file)
    if ext_specs is None:
        return False
    try:
        torch.ops.load_library(ext_specs.origin)
    except OSError:
        # A pre-compiled binary exists but is incompatible (wrong ABI /
        # CUDA version); report failure so the caller can JIT-compile.
        return False
    return True
######################################################################################################## ########################################################################################################
if os.environ.get("RWKV_JIT_ON") != "0": if os.environ.get("RWKV_JIT_ON") != "0":
@ -29,45 +79,48 @@ else:
MyStatic = __nop MyStatic = __nop
if os.environ.get("RWKV_CUDA_ON") == "1": if os.environ.get("RWKV_CUDA_ON") == "1":
from torch.utils.cpp_extension import load DISABLE_CUBLAS_GEMM = False
from torch.utils.cpp_extension import load # L581
try: if LoadPreCompileLibrary("wkv_cuda") is False:
load( try:
name=f"wkv_cuda", load(
sources=[ name=f"wkv_cuda",
f"{current_path}/cuda/wrapper.cpp", sources=[
f"{current_path}/cuda/operators.cu", f"{current_path}/cuda/wrapper.cpp",
f"{current_path}/cuda/gemm_fp16_cublas.cpp", f"{current_path}/cuda/operators.cu",
], f"{current_path}/cuda/gemm_fp16_cublas.cpp",
verbose=True, ],
extra_cuda_cflags=[ verbose=True,
"--use_fast_math", extra_ldflags=["cublas.lib"],
"-O3", extra_cuda_cflags=[
"--extra-device-vectorization", "--use_fast_math",
], "-O3",
is_python_module=False, "--extra-device-vectorization",
) ],
DISABLE_CUBLAS_GEMM = False is_python_module=False,
except: )
print( DISABLE_CUBLAS_GEMM = False
"Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow." except:
) print(
load( "Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow."
name=f"wkv_cuda", )
sources=[ load(
f"{current_path}/cuda/wrapper.cpp", name=f"wkv_cuda",
f"{current_path}/cuda/operators.cu", sources=[
], f"{current_path}/cuda/wrapper.cpp",
verbose=True, f"{current_path}/cuda/operators.cu",
extra_cuda_cflags=[ ],
"--use_fast_math", verbose=True,
"-O3", extra_cuda_cflags=[
"--extra-device-vectorization", "--use_fast_math",
], "-O3",
extra_cflags=["-DDISABLE_CUBLAS_GEMM"], "--extra-device-vectorization",
is_python_module=False, ],
) extra_cflags=["-DDISABLE_CUBLAS_GEMM"],
DISABLE_CUBLAS_GEMM = True is_python_module=False,
)
DISABLE_CUBLAS_GEMM = True
@MyStatic @MyStatic
def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp): def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp):
@ -527,24 +580,26 @@ class RWKV(MyModule):
if self.version == 5.2: if self.version == 5.2:
assert ( assert (
os.environ["RWKV_CUDA_ON"] == "1" os.environ["RWKV_CUDA_ON"] == "1"
) # latest RWKV-5 requires os.environ["RWKV_CUDA_ON"] == '1' (will fix soon) ), "Please Enable Custom CUDA Kernel. Latest RWKV-5 requires os.environ['RWKV_CUDA_ON'] == '1' (will fix soon)"
HEAD_SIZE = args.n_att // args.n_head HEAD_SIZE = args.n_att // args.n_head
rwkv5 = load( if LoadPreCompileLibrary("rwkv5") is True:
name="rwkv5", rwkv5 = torch.ops.rwkv5
sources=[ else:
f"{current_path}/cuda/rwkv5_op.cpp", rwkv5 = load(
f"{current_path}/cuda/rwkv5.cu", name="rwkv5",
], sources=[
verbose=True, f"{current_path}/cuda/rwkv5_op.cpp",
extra_cuda_cflags=[ f"{current_path}/cuda/rwkv5.cu",
"-res-usage", ],
"--use_fast_math", verbose=True,
"-O3", extra_cuda_cflags=[
"-Xptxas -O3", "-res-usage",
"--extra-device-vectorization", "--use_fast_math",
f"-D_N_={HEAD_SIZE}", "-O3",
], "--extra-device-vectorization",
) f"-D_N_={HEAD_SIZE}",
],
)
class RWKV_5(torch.autograd.Function): class RWKV_5(torch.autograd.Function):
@staticmethod @staticmethod

BIN
backend-python/rwkv_pip/rwkv5.pyd vendored Normal file

Binary file not shown.

BIN
backend-python/rwkv_pip/wkv_cuda.pyd vendored Normal file

Binary file not shown.