diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..9c2f2c0 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,18 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + // + // Use Ctrl+Shift+P to Select Interpreter + "version": "0.2.0", + "configurations": [ + { + "name": "Python", + "type": "python", + "request": "launch", + "program": "./backend-python/main.py", + "console": "integratedTerminal", + "justMyCode": false, + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 46bd03c..7984937 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,6 +2,6 @@ "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" }, - "python.formatting.provider": "black", + "python.formatting.provider": "none", "editor.formatOnSave": true } \ No newline at end of file diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py index bd7a18b..41742ea 100644 --- a/backend-python/routes/config.py +++ b/backend-python/routes/config.py @@ -13,6 +13,7 @@ router = APIRouter() class SwitchModelBody(BaseModel): model: str strategy: str + customCuda: bool = False @router.post("/switch-model") @@ -24,6 +25,8 @@ def switch_model(body: SwitchModelBody, response: Response): global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline) global_var.set(global_var.Model, None) torch_gc() + + os.environ["RWKV_CUDA_ON"] = '1' if body.customCuda else '0' global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading) try: diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index 7a5cb6a..77e28a7 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -36,8 +36,7 @@ def get_rwkv_config(model: RWKV) -> ModelConfigBody: ) -# os.environ["RWKV_CUDA_ON"] = '1' -# os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}" +os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}" def rwkv_generate(model: RWKV, prompt: str, stop: str = None):