SwitchModelBody.customCuda

This commit is contained in:
josc146 2023-05-23 11:51:43 +08:00
parent 7989e93afe
commit 524d9e78e6
4 changed files with 23 additions and 3 deletions

18
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,18 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
//
// Use Ctrl+Shift+P to Select Interpreter
"version": "0.2.0",
"configurations": [
{
"name": "Python",
"type": "python",
"request": "launch",
"program": "./backend-python/main.py",
"console": "integratedTerminal",
"justMyCode": false,
}
]
}

View File

@ -2,6 +2,6 @@
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "black",
"python.formatting.provider": "none",
"editor.formatOnSave": true
}

View File

@ -13,6 +13,7 @@ router = APIRouter()
class SwitchModelBody(BaseModel):
model: str
strategy: str
customCuda: bool = False
@router.post("/switch-model")
@ -24,6 +25,8 @@ def switch_model(body: SwitchModelBody, response: Response):
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
global_var.set(global_var.Model, None)
torch_gc()
os.environ["RWKV_CUDA_ON"] = '1' if body.customCuda else '0'
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
try:

View File

@ -36,8 +36,7 @@ def get_rwkv_config(model: RWKV) -> ModelConfigBody:
)
# os.environ["RWKV_CUDA_ON"] = '1'
# os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):