rwkv5 pre-compiled kernel (for windows)
This commit is contained in:
parent
79851433f8
commit
68228a4552
165
backend-python/rwkv_pip/model.py
vendored
165
backend-python/rwkv_pip/model.py
vendored
@ -12,6 +12,56 @@ torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
current_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
# https://zhuanlan.zhihu.com/p/612879065
|
||||
def LoadPreCompileLibrary(file):
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
# load the custom_op_library and register the custom ops
|
||||
lib_dir = os.path.dirname(__file__)
|
||||
if os.name == "nt":
|
||||
# Register the main torchvision library location on the default DLL path
|
||||
import ctypes
|
||||
import sys
|
||||
|
||||
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
|
||||
with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
|
||||
prev_error_mode = kernel32.SetErrorMode(0x0001)
|
||||
|
||||
if with_load_library_flags:
|
||||
kernel32.AddDllDirectory.restype = ctypes.c_void_p
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
os.add_dll_directory(lib_dir)
|
||||
elif with_load_library_flags:
|
||||
res = kernel32.AddDllDirectory(lib_dir)
|
||||
if res is None:
|
||||
err = ctypes.WinError(ctypes.get_last_error())
|
||||
err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
|
||||
raise ValueError(err)
|
||||
|
||||
kernel32.SetErrorMode(prev_error_mode)
|
||||
|
||||
loader_details = (
|
||||
importlib.machinery.ExtensionFileLoader,
|
||||
importlib.machinery.EXTENSION_SUFFIXES,
|
||||
)
|
||||
|
||||
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
|
||||
ext_specs = extfinder.find_spec(file)
|
||||
if ext_specs is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
torch.ops.load_library(ext_specs.origin)
|
||||
except OSError as exc:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
########################################################################################################
|
||||
|
||||
if os.environ.get("RWKV_JIT_ON") != "0":
|
||||
@ -29,45 +79,48 @@ else:
|
||||
MyStatic = __nop
|
||||
|
||||
if os.environ.get("RWKV_CUDA_ON") == "1":
|
||||
from torch.utils.cpp_extension import load
|
||||
DISABLE_CUBLAS_GEMM = False
|
||||
from torch.utils.cpp_extension import load # L581
|
||||
|
||||
try:
|
||||
load(
|
||||
name=f"wkv_cuda",
|
||||
sources=[
|
||||
f"{current_path}/cuda/wrapper.cpp",
|
||||
f"{current_path}/cuda/operators.cu",
|
||||
f"{current_path}/cuda/gemm_fp16_cublas.cpp",
|
||||
],
|
||||
verbose=True,
|
||||
extra_cuda_cflags=[
|
||||
"--use_fast_math",
|
||||
"-O3",
|
||||
"--extra-device-vectorization",
|
||||
],
|
||||
is_python_module=False,
|
||||
)
|
||||
DISABLE_CUBLAS_GEMM = False
|
||||
except:
|
||||
print(
|
||||
"Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow."
|
||||
)
|
||||
load(
|
||||
name=f"wkv_cuda",
|
||||
sources=[
|
||||
f"{current_path}/cuda/wrapper.cpp",
|
||||
f"{current_path}/cuda/operators.cu",
|
||||
],
|
||||
verbose=True,
|
||||
extra_cuda_cflags=[
|
||||
"--use_fast_math",
|
||||
"-O3",
|
||||
"--extra-device-vectorization",
|
||||
],
|
||||
extra_cflags=["-DDISABLE_CUBLAS_GEMM"],
|
||||
is_python_module=False,
|
||||
)
|
||||
DISABLE_CUBLAS_GEMM = True
|
||||
if LoadPreCompileLibrary("wkv_cuda") is False:
|
||||
try:
|
||||
load(
|
||||
name=f"wkv_cuda",
|
||||
sources=[
|
||||
f"{current_path}/cuda/wrapper.cpp",
|
||||
f"{current_path}/cuda/operators.cu",
|
||||
f"{current_path}/cuda/gemm_fp16_cublas.cpp",
|
||||
],
|
||||
verbose=True,
|
||||
extra_ldflags=["cublas.lib"],
|
||||
extra_cuda_cflags=[
|
||||
"--use_fast_math",
|
||||
"-O3",
|
||||
"--extra-device-vectorization",
|
||||
],
|
||||
is_python_module=False,
|
||||
)
|
||||
DISABLE_CUBLAS_GEMM = False
|
||||
except:
|
||||
print(
|
||||
"Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow."
|
||||
)
|
||||
load(
|
||||
name=f"wkv_cuda",
|
||||
sources=[
|
||||
f"{current_path}/cuda/wrapper.cpp",
|
||||
f"{current_path}/cuda/operators.cu",
|
||||
],
|
||||
verbose=True,
|
||||
extra_cuda_cflags=[
|
||||
"--use_fast_math",
|
||||
"-O3",
|
||||
"--extra-device-vectorization",
|
||||
],
|
||||
extra_cflags=["-DDISABLE_CUBLAS_GEMM"],
|
||||
is_python_module=False,
|
||||
)
|
||||
DISABLE_CUBLAS_GEMM = True
|
||||
|
||||
@MyStatic
|
||||
def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp):
|
||||
@ -527,24 +580,26 @@ class RWKV(MyModule):
|
||||
if self.version == 5.2:
|
||||
assert (
|
||||
os.environ["RWKV_CUDA_ON"] == "1"
|
||||
) # latest RWKV-5 requires os.environ["RWKV_CUDA_ON"] == '1' (will fix soon)
|
||||
), "Please Enable Custom CUDA Kernel. Latest RWKV-5 requires os.environ['RWKV_CUDA_ON'] == '1' (will fix soon)"
|
||||
HEAD_SIZE = args.n_att // args.n_head
|
||||
rwkv5 = load(
|
||||
name="rwkv5",
|
||||
sources=[
|
||||
f"{current_path}/cuda/rwkv5_op.cpp",
|
||||
f"{current_path}/cuda/rwkv5.cu",
|
||||
],
|
||||
verbose=True,
|
||||
extra_cuda_cflags=[
|
||||
"-res-usage",
|
||||
"--use_fast_math",
|
||||
"-O3",
|
||||
"-Xptxas -O3",
|
||||
"--extra-device-vectorization",
|
||||
f"-D_N_={HEAD_SIZE}",
|
||||
],
|
||||
)
|
||||
if LoadPreCompileLibrary("rwkv5") is True:
|
||||
rwkv5 = torch.ops.rwkv5
|
||||
else:
|
||||
rwkv5 = load(
|
||||
name="rwkv5",
|
||||
sources=[
|
||||
f"{current_path}/cuda/rwkv5_op.cpp",
|
||||
f"{current_path}/cuda/rwkv5.cu",
|
||||
],
|
||||
verbose=True,
|
||||
extra_cuda_cflags=[
|
||||
"-res-usage",
|
||||
"--use_fast_math",
|
||||
"-O3",
|
||||
"--extra-device-vectorization",
|
||||
f"-D_N_={HEAD_SIZE}",
|
||||
],
|
||||
)
|
||||
|
||||
class RWKV_5(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
BIN
backend-python/rwkv_pip/rwkv5.pyd
vendored
Normal file
BIN
backend-python/rwkv_pip/rwkv5.pyd
vendored
Normal file
Binary file not shown.
BIN
backend-python/rwkv_pip/wkv_cuda.pyd
vendored
Normal file
BIN
backend-python/rwkv_pip/wkv_cuda.pyd
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user