# https://zhuanlan.zhihu.com/p/612879065
def LoadPreCompileLibrary(file: str) -> bool:
    """Try to load a precompiled extension library located next to this module.

    The directory containing this source file is searched for an extension
    module named ``file`` (using the interpreter's extension suffixes, e.g.
    ``.pyd`` / ``.so``). When one is found it is handed to
    ``torch.ops.load_library`` so the custom ops it contains get registered.

    Args:
        file: Base name of the extension module, without suffix
            (for example ``"wkv_cuda"``).

    Returns:
        ``True`` if a matching library was found and loaded, ``False`` if no
        matching file exists or loading it raised ``OSError``.

    Raises:
        ValueError: On Windows with Python < 3.8, when the directory could
            not be added to the DLL search path.
    """
    import importlib.machinery
    import os

    # Use an absolute path: os.add_dll_directory / AddDllDirectory reject
    # relative paths, and dirname(__file__) may be relative or empty.
    lib_dir = os.path.dirname(os.path.abspath(__file__))

    if os.name == "nt":
        # Put this directory on the DLL search path so that dependencies of
        # the precompiled library can be resolved when it is loaded.
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        # 0x0001 is SEM_FAILCRITICALERRORS; the previous mode is restored below.
        prev_error_mode = kernel32.SetErrorMode(0x0001)
        try:
            if with_load_library_flags:
                kernel32.AddDllDirectory.restype = ctypes.c_void_p

            if sys.version_info >= (3, 8):
                os.add_dll_directory(lib_dir)
            elif with_load_library_flags:
                res = kernel32.AddDllDirectory(lib_dir)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += (
                        f' Error adding "{lib_dir}" to the DLL directories.'
                    )
                    raise ValueError(err)
        finally:
            # Always restore the process error mode, even if registration failed.
            kernel32.SetErrorMode(prev_error_mode)

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES,
    )
    ext_finder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_spec = ext_finder.find_spec(file)
    if ext_spec is None:
        # No precompiled binary is present; the caller decides what to do next.
        return False

    # Imported only once a library was actually found.
    import torch

    try:
        torch.ops.load_library(ext_spec.origin)
    except OSError:
        # Deliberate best-effort: a library that cannot be loaded is reported
        # the same way as a missing one.
        return False
    return True
- ) - load( - name=f"wkv_cuda", - sources=[ - f"{current_path}/cuda/wrapper.cpp", - f"{current_path}/cuda/operators.cu", - ], - verbose=True, - extra_cuda_cflags=[ - "--use_fast_math", - "-O3", - "--extra-device-vectorization", - ], - extra_cflags=["-DDISABLE_CUBLAS_GEMM"], - is_python_module=False, - ) - DISABLE_CUBLAS_GEMM = True + if LoadPreCompileLibrary("wkv_cuda") is False: + try: + load( + name=f"wkv_cuda", + sources=[ + f"{current_path}/cuda/wrapper.cpp", + f"{current_path}/cuda/operators.cu", + f"{current_path}/cuda/gemm_fp16_cublas.cpp", + ], + verbose=True, + extra_ldflags=["cublas.lib"], + extra_cuda_cflags=[ + "--use_fast_math", + "-O3", + "--extra-device-vectorization", + ], + is_python_module=False, + ) + DISABLE_CUBLAS_GEMM = False + except: + print( + "Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow." + ) + load( + name=f"wkv_cuda", + sources=[ + f"{current_path}/cuda/wrapper.cpp", + f"{current_path}/cuda/operators.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "--use_fast_math", + "-O3", + "--extra-device-vectorization", + ], + extra_cflags=["-DDISABLE_CUBLAS_GEMM"], + is_python_module=False, + ) + DISABLE_CUBLAS_GEMM = True @MyStatic def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp): @@ -527,24 +580,26 @@ class RWKV(MyModule): if self.version == 5.2: assert ( os.environ["RWKV_CUDA_ON"] == "1" - ) # latest RWKV-5 requires os.environ["RWKV_CUDA_ON"] == '1' (will fix soon) + ), "Please Enable Custom CUDA Kernel. 
Latest RWKV-5 requires os.environ['RWKV_CUDA_ON'] == '1' (will fix soon)" HEAD_SIZE = args.n_att // args.n_head - rwkv5 = load( - name="rwkv5", - sources=[ - f"{current_path}/cuda/rwkv5_op.cpp", - f"{current_path}/cuda/rwkv5.cu", - ], - verbose=True, - extra_cuda_cflags=[ - "-res-usage", - "--use_fast_math", - "-O3", - "-Xptxas -O3", - "--extra-device-vectorization", - f"-D_N_={HEAD_SIZE}", - ], - ) + if LoadPreCompileLibrary("rwkv5") is True: + rwkv5 = torch.ops.rwkv5 + else: + rwkv5 = load( + name="rwkv5", + sources=[ + f"{current_path}/cuda/rwkv5_op.cpp", + f"{current_path}/cuda/rwkv5.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + ], + ) class RWKV_5(torch.autograd.Function): @staticmethod diff --git a/backend-python/rwkv_pip/rwkv5.pyd b/backend-python/rwkv_pip/rwkv5.pyd new file mode 100644 index 0000000..79f7baa Binary files /dev/null and b/backend-python/rwkv_pip/rwkv5.pyd differ diff --git a/backend-python/rwkv_pip/wkv_cuda.pyd b/backend-python/rwkv_pip/wkv_cuda.pyd new file mode 100644 index 0000000..324ed47 Binary files /dev/null and b/backend-python/rwkv_pip/wkv_cuda.pyd differ