fix cross-device state cache exception
This commit is contained in:
parent
b9a960d984
commit
994fc7c828
@ -52,6 +52,14 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
|||||||
if body.model == "":
|
if body.model == "":
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
|
if "->" in body.strategy:
|
||||||
|
state_cache.disable_state_cache()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
state_cache.enable_state_cache()
|
||||||
|
except HTTPException:
|
||||||
|
pass
|
||||||
|
|
||||||
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
|
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
|
||||||
|
|
||||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
||||||
|
@ -34,6 +34,32 @@ def init():
|
|||||||
print("cyac not found")
|
print("cyac not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/disable-state-cache")
|
||||||
|
def disable_state_cache():
|
||||||
|
global trie, dtrie
|
||||||
|
|
||||||
|
trie = None
|
||||||
|
dtrie = {}
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/enable-state-cache")
|
||||||
|
def enable_state_cache():
|
||||||
|
global trie, dtrie
|
||||||
|
try:
|
||||||
|
import cyac
|
||||||
|
|
||||||
|
trie = cyac.Trie()
|
||||||
|
dtrie = {}
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return "success"
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")
|
||||||
|
|
||||||
|
|
||||||
class AddStateBody(BaseModel):
|
class AddStateBody(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
tokens: List[str]
|
tokens: List[str]
|
||||||
@ -85,6 +111,8 @@ def reset_state():
|
|||||||
if trie is None:
|
if trie is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
|
import cyac
|
||||||
|
|
||||||
trie = cyac.Trie()
|
trie = cyac.Trie()
|
||||||
dtrie = {}
|
dtrie = {}
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
Loading…
Reference in New Issue
Block a user