rwkv5 pre-compiled kernel (for windows)
This commit is contained in:
parent
79851433f8
commit
68228a4552
165
backend-python/rwkv_pip/model.py
vendored
165
backend-python/rwkv_pip/model.py
vendored
@ -12,6 +12,56 @@ torch.backends.cudnn.allow_tf32 = True
|
|||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
current_path = os.path.dirname(os.path.abspath(__file__))
|
current_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
# https://zhuanlan.zhihu.com/p/612879065
|
||||||
|
def LoadPreCompileLibrary(file):
    """Try to load a pre-compiled torch extension located next to this module.

    The directory containing this file is searched for an extension module
    named ``file`` (using the interpreter's extension suffixes, e.g. ``.pyd``
    on Windows). If one is found it is registered with
    ``torch.ops.load_library`` so its custom ops become available.

    Args:
        file: Base name of the extension module, without suffix
            (for example ``"wkv_cuda"`` or ``"rwkv5"``).

    Returns:
        True if the library was found and loaded, False if no matching file
        exists or the dynamic loader rejected it (``OSError``). Callers use
        False as the signal to fall back to building the extension.

    Raises:
        ValueError: On Windows with Python < 3.8, when the directory cannot be
            added to the DLL search path.
    """
    # `import importlib` alone does not guarantee that the `machinery`
    # submodule is bound as an attribute, so import it explicitly.
    import importlib.machinery
    import os

    # Use an absolute path: `__file__` may be relative, and
    # os.add_dll_directory() rejects relative or empty paths.
    lib_dir = os.path.dirname(os.path.abspath(__file__))

    if os.name == "nt":
        # Put this directory on the DLL search path so that dependencies of
        # the pre-compiled extension can be resolved when it is loaded.
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        # 0x0001 is the flag value passed to SetErrorMode while the search
        # path is modified; the previous mode is restored in `finally`.
        prev_error_mode = kernel32.SetErrorMode(0x0001)
        try:
            if with_load_library_flags:
                kernel32.AddDllDirectory.restype = ctypes.c_void_p

            if sys.version_info >= (3, 8):
                os.add_dll_directory(lib_dir)
            elif with_load_library_flags:
                res = kernel32.AddDllDirectory(lib_dir)
                if res is None:
                    err = ctypes.WinError(ctypes.get_last_error())
                    err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
                    raise ValueError(err)
        finally:
            # Restore the process error mode even if registration failed.
            kernel32.SetErrorMode(prev_error_mode)

    loader_details = (
        importlib.machinery.ExtensionFileLoader,
        importlib.machinery.EXTENSION_SUFFIXES,
    )
    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec(file)
    # No matching extension file (or a spec without a file behind it).
    if ext_specs is None or ext_specs.origin is None:
        return False

    # Imported only once there is something to load.
    import torch

    try:
        torch.ops.load_library(ext_specs.origin)
    except OSError:
        # Deliberate best-effort: an unloadable binary is reported as
        # "not available" so the caller can fall back to another path.
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
########################################################################################################
|
########################################################################################################
|
||||||
|
|
||||||
if os.environ.get("RWKV_JIT_ON") != "0":
|
if os.environ.get("RWKV_JIT_ON") != "0":
|
||||||
@ -29,45 +79,48 @@ else:
|
|||||||
MyStatic = __nop
|
MyStatic = __nop
|
||||||
|
|
||||||
if os.environ.get("RWKV_CUDA_ON") == "1":
|
if os.environ.get("RWKV_CUDA_ON") == "1":
|
||||||
from torch.utils.cpp_extension import load
|
DISABLE_CUBLAS_GEMM = False
|
||||||
|
from torch.utils.cpp_extension import load # L581
|
||||||
|
|
||||||
try:
|
if LoadPreCompileLibrary("wkv_cuda") is False:
|
||||||
load(
|
try:
|
||||||
name=f"wkv_cuda",
|
load(
|
||||||
sources=[
|
name=f"wkv_cuda",
|
||||||
f"{current_path}/cuda/wrapper.cpp",
|
sources=[
|
||||||
f"{current_path}/cuda/operators.cu",
|
f"{current_path}/cuda/wrapper.cpp",
|
||||||
f"{current_path}/cuda/gemm_fp16_cublas.cpp",
|
f"{current_path}/cuda/operators.cu",
|
||||||
],
|
f"{current_path}/cuda/gemm_fp16_cublas.cpp",
|
||||||
verbose=True,
|
],
|
||||||
extra_cuda_cflags=[
|
verbose=True,
|
||||||
"--use_fast_math",
|
extra_ldflags=["cublas.lib"],
|
||||||
"-O3",
|
extra_cuda_cflags=[
|
||||||
"--extra-device-vectorization",
|
"--use_fast_math",
|
||||||
],
|
"-O3",
|
||||||
is_python_module=False,
|
"--extra-device-vectorization",
|
||||||
)
|
],
|
||||||
DISABLE_CUBLAS_GEMM = False
|
is_python_module=False,
|
||||||
except:
|
)
|
||||||
print(
|
DISABLE_CUBLAS_GEMM = False
|
||||||
"Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow."
|
except:
|
||||||
)
|
print(
|
||||||
load(
|
"Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow."
|
||||||
name=f"wkv_cuda",
|
)
|
||||||
sources=[
|
load(
|
||||||
f"{current_path}/cuda/wrapper.cpp",
|
name=f"wkv_cuda",
|
||||||
f"{current_path}/cuda/operators.cu",
|
sources=[
|
||||||
],
|
f"{current_path}/cuda/wrapper.cpp",
|
||||||
verbose=True,
|
f"{current_path}/cuda/operators.cu",
|
||||||
extra_cuda_cflags=[
|
],
|
||||||
"--use_fast_math",
|
verbose=True,
|
||||||
"-O3",
|
extra_cuda_cflags=[
|
||||||
"--extra-device-vectorization",
|
"--use_fast_math",
|
||||||
],
|
"-O3",
|
||||||
extra_cflags=["-DDISABLE_CUBLAS_GEMM"],
|
"--extra-device-vectorization",
|
||||||
is_python_module=False,
|
],
|
||||||
)
|
extra_cflags=["-DDISABLE_CUBLAS_GEMM"],
|
||||||
DISABLE_CUBLAS_GEMM = True
|
is_python_module=False,
|
||||||
|
)
|
||||||
|
DISABLE_CUBLAS_GEMM = True
|
||||||
|
|
||||||
@MyStatic
|
@MyStatic
|
||||||
def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp):
|
def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp):
|
||||||
@ -527,24 +580,26 @@ class RWKV(MyModule):
|
|||||||
if self.version == 5.2:
|
if self.version == 5.2:
|
||||||
assert (
|
assert (
|
||||||
os.environ["RWKV_CUDA_ON"] == "1"
|
os.environ["RWKV_CUDA_ON"] == "1"
|
||||||
) # latest RWKV-5 requires os.environ["RWKV_CUDA_ON"] == '1' (will fix soon)
|
), "Please Enable Custom CUDA Kernel. Latest RWKV-5 requires os.environ['RWKV_CUDA_ON'] == '1' (will fix soon)"
|
||||||
HEAD_SIZE = args.n_att // args.n_head
|
HEAD_SIZE = args.n_att // args.n_head
|
||||||
rwkv5 = load(
|
if LoadPreCompileLibrary("rwkv5") is True:
|
||||||
name="rwkv5",
|
rwkv5 = torch.ops.rwkv5
|
||||||
sources=[
|
else:
|
||||||
f"{current_path}/cuda/rwkv5_op.cpp",
|
rwkv5 = load(
|
||||||
f"{current_path}/cuda/rwkv5.cu",
|
name="rwkv5",
|
||||||
],
|
sources=[
|
||||||
verbose=True,
|
f"{current_path}/cuda/rwkv5_op.cpp",
|
||||||
extra_cuda_cflags=[
|
f"{current_path}/cuda/rwkv5.cu",
|
||||||
"-res-usage",
|
],
|
||||||
"--use_fast_math",
|
verbose=True,
|
||||||
"-O3",
|
extra_cuda_cflags=[
|
||||||
"-Xptxas -O3",
|
"-res-usage",
|
||||||
"--extra-device-vectorization",
|
"--use_fast_math",
|
||||||
f"-D_N_={HEAD_SIZE}",
|
"-O3",
|
||||||
],
|
"--extra-device-vectorization",
|
||||||
)
|
f"-D_N_={HEAD_SIZE}",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
class RWKV_5(torch.autograd.Function):
|
class RWKV_5(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
BIN
backend-python/rwkv_pip/rwkv5.pyd
vendored
Normal file
BIN
backend-python/rwkv_pip/rwkv5.pyd
vendored
Normal file
Binary file not shown.
BIN
backend-python/rwkv_pip/wkv_cuda.pyd
vendored
Normal file
BIN
backend-python/rwkv_pip/wkv_cuda.pyd
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user