SwitchModelBody.customCuda

This commit is contained in:
josc146
2023-05-23 11:51:43 +08:00
parent 7989e93afe
commit 524d9e78e6
4 changed files with 23 additions and 3 deletions

View File

@@ -36,8 +36,7 @@ def get_rwkv_config(model: RWKV) -> ModelConfigBody:
)
# os.environ["RWKV_CUDA_ON"] = '1'
# os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):