fix cross-device state cache exception

This commit is contained in:
josc146 2023-07-11 11:20:12 +08:00
parent b9a960d984
commit 994fc7c828
2 changed files with 36 additions and 0 deletions

View File

@ -52,6 +52,14 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
if body.model == "":
return "success"
if "->" in body.strategy:
state_cache.disable_state_cache()
else:
try:
state_cache.enable_state_cache()
except HTTPException:
pass
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)

View File

@ -34,6 +34,32 @@ def init():
print("cyac not found")
@router.post("/disable-state-cache")
def disable_state_cache():
global trie, dtrie
trie = None
dtrie = {}
gc.collect()
return "success"
@router.post("/enable-state-cache")
def enable_state_cache():
global trie, dtrie
try:
import cyac
trie = cyac.Trie()
dtrie = {}
gc.collect()
return "success"
except ModuleNotFoundError:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")
class AddStateBody(BaseModel):
prompt: str
tokens: List[str]
@ -85,6 +111,8 @@ def reset_state():
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
import cyac
trie = cyac.Trie()
dtrie = {}
gc.collect()