diff --git a/backend-python/rwkv_pip/model.py b/backend-python/rwkv_pip/model.py index c92f94d..81469b9 100644 --- a/backend-python/rwkv_pip/model.py +++ b/backend-python/rwkv_pip/model.py @@ -92,7 +92,7 @@ if os.environ.get("RWKV_CUDA_ON") == "1": f"{current_path}/cuda/gemm_fp16_cublas.cpp", ], verbose=True, - extra_ldflags=["cublas.lib"], + extra_ldflags=["cublas.lib" if os.name == "nt" else ""], extra_cuda_cflags=[ "--use_fast_math", "-O3", @@ -596,6 +596,7 @@ class RWKV(MyModule): "-res-usage", "--use_fast_math", "-O3", + "-Xptxas -O3" if os.name != "nt" else "", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", ],