SwitchModelBody.customCuda

This commit is contained in:
josc146
2023-05-23 11:51:43 +08:00
parent 7989e93afe
commit 524d9e78e6
4 changed files with 23 additions and 3 deletions

View File

@@ -13,6 +13,7 @@ router = APIRouter()
class SwitchModelBody(BaseModel):
model: str
strategy: str
customCuda: bool = False
@router.post("/switch-model")
@@ -24,6 +25,8 @@ def switch_model(body: SwitchModelBody, response: Response):
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
global_var.set(global_var.Model, None)
torch_gc()
os.environ["RWKV_CUDA_ON"] = '1' if body.customCuda else '0'
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
try: