SwitchModelBody.customCuda
This commit is contained in:
parent
7989e93afe
commit
524d9e78e6
18
.vscode/launch.json
vendored
Normal file
18
.vscode/launch.json
vendored
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
//
|
||||||
|
// Use Ctrl+Shift+P to Select Interpreter
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Python",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "./backend-python/main.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
2
.vscode/settings.json
vendored
2
.vscode/settings.json
vendored
@ -2,6 +2,6 @@
|
|||||||
"[python]": {
|
"[python]": {
|
||||||
"editor.defaultFormatter": "ms-python.black-formatter"
|
"editor.defaultFormatter": "ms-python.black-formatter"
|
||||||
},
|
},
|
||||||
"python.formatting.provider": "black",
|
"python.formatting.provider": "none",
|
||||||
"editor.formatOnSave": true
|
"editor.formatOnSave": true
|
||||||
}
|
}
|
@ -13,6 +13,7 @@ router = APIRouter()
|
|||||||
class SwitchModelBody(BaseModel):
|
class SwitchModelBody(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
strategy: str
|
strategy: str
|
||||||
|
customCuda: bool = False
|
||||||
|
|
||||||
|
|
||||||
@router.post("/switch-model")
|
@router.post("/switch-model")
|
||||||
@ -24,6 +25,8 @@ def switch_model(body: SwitchModelBody, response: Response):
|
|||||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
|
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
|
||||||
global_var.set(global_var.Model, None)
|
global_var.set(global_var.Model, None)
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
|
os.environ["RWKV_CUDA_ON"] = '1' if body.customCuda else '0'
|
||||||
|
|
||||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
||||||
try:
|
try:
|
||||||
|
@ -36,8 +36,7 @@ def get_rwkv_config(model: RWKV) -> ModelConfigBody:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# os.environ["RWKV_CUDA_ON"] = '1'
|
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
|
||||||
# os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
|
|
||||||
|
|
||||||
|
|
||||||
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):
|
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):
|
||||||
|
Loading…
Reference in New Issue
Block a user