better state cache

josc146 2023-12-08 15:28:33 +08:00
parent 75244fbd8b
commit d9e25ad69f
2 changed files with 29 additions and 24 deletions

View File

@@ -49,9 +49,22 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
     if body.model == "":
         return "success"
 
-    if "->" in body.strategy:
-        state_cache.disable_state_cache()
-    else:
-        try:
-            state_cache.enable_state_cache()
-        except HTTPException:
+    STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
+    if not re.match(STRATEGY_REGEX, body.strategy):
+        raise HTTPException(
+            Status.HTTP_400_BAD_REQUEST,
+            "Invalid strategy. Please read https://pypi.org/project/rwkv/",
+        )
+
+    devices = set(
+        [
+            x.strip().split(" ")[0].replace("cuda:0", "cuda")
+            for x in body.strategy.split("->")
+        ]
+    )
+    print(f"Devices: {devices}")
+    # if len(devices) > 1:
+    #     state_cache.disable_state_cache()
+    # else:
+    try:
+        state_cache.enable_state_cache()
+    except HTTPException:
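
For reference, a runnable sketch of what the new validation does, outside the route handler. It reuses the regex and the device-extraction logic from the hunk above; parse_strategy is a hypothetical helper name, not part of the commit, and ValueError stands in for the route's HTTPException.

import re

# Same pattern as in the hunk above; accepts strategies such as "cpu fp32"
# or "cuda:0 fp16 *20 -> cuda:1 fp16" and rejects malformed ones.
STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"

def parse_strategy(strategy: str) -> set:
    # Hypothetical helper for illustration; the route inlines this logic.
    if not re.match(STRATEGY_REGEX, strategy):
        raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/")
    # Collect the distinct devices, normalizing "cuda:0" to "cuda".
    return set(
        x.strip().split(" ")[0].replace("cuda:0", "cuda")
        for x in strategy.split("->")
    )

print(parse_strategy("cuda fp16"))                               # {'cuda'}
print(parse_strategy("cpu fp32"))                                # {'cpu'}
print(sorted(parse_strategy("cuda:0 fp16 *20 -> cuda:1 fp16")))  # ['cuda', 'cuda:1']
try:
    parse_strategy("cuda")  # missing dtype, rejected by the regex
except ValueError as e:
    print(e)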

View File

@@ -44,6 +44,7 @@ def disable_state_cache():
     dtrie = {}
     gc.collect()
+    print("state cache disabled")
     return "success"
@@ -61,8 +62,10 @@ def enable_state_cache():
         dtrie = {}
         gc.collect()
+        print("state cache enabled")
         return "success"
     except ModuleNotFoundError:
+        print("state cache disabled")
         raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")
@@ -87,14 +90,12 @@ def add_state(body: AddStateBody):
     try:
         id: int = trie.insert(body.prompt)
-        device: torch.device = body.state[0].device
+        devices: List[torch.device] = [tensor.device for tensor in body.state]
         dtrie[id] = {
             "tokens": copy.deepcopy(body.tokens),
-            "state": [tensor.cpu() for tensor in body.state]
-            if device != torch.device("cpu")
-            else copy.deepcopy(body.state),
+            "state": [tensor.cpu() for tensor in body.state],
             "logits": copy.deepcopy(body.logits),
-            "device": device,
+            "devices": devices,
         }
 
         if len(trie) >= max_trie_len:
@@ -177,27 +178,18 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
         pass
 
     if id != -1:
         v = dtrie[id]
-        device: torch.device = v["device"]
+        devices: List[torch.device] = v["devices"]
         prompt: str = trie[id]
         quick_log(request, body, "Hit:\n" + prompt)
         return {
             "prompt": prompt,
             "tokens": v["tokens"],
-            "state": [tensor.to(device) for tensor in v["state"]]
-            if device != torch.device("cpu")
-            else v["state"],
+            "state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])],
             "logits": v["logits"],
-            "device": device.type,
         }
     else:
-        return {
-            "prompt": "",
-            "tokens": [],
-            "state": None,
-            "logits": None,
-            "device": None,
-        }
+        return {"prompt": "", "tokens": [], "state": None, "logits": None}
 
 # @router.post("/save-state", tags=["State Cache"])
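
Taken together, the add_state and longest_prefix_state hunks switch the cache from one shared device to a per-tensor device list: tensors are always parked on CPU when stored, and each one is moved back to its recorded device on a hit. A minimal sketch of that round-trip, assuming torch is installed; cache, save_state, and load_state are illustrative names rather than the handlers above.

from typing import Any, Dict, List

import torch

cache: Dict[int, Dict[str, Any]] = {}  # illustrative stand-in for dtrie

def save_state(id: int, state: List[torch.Tensor]) -> None:
    # As in add_state: always park the tensors on CPU, but remember each
    # tensor's original device so a mixed-device strategy round-trips.
    cache[id] = {
        "state": [tensor.cpu() for tensor in state],
        "devices": [tensor.device for tensor in state],
    }

def load_state(id: int) -> List[torch.Tensor]:
    # As in longest_prefix_state: move every tensor back to the device it
    # was recorded from (a no-op for tensors that were already on CPU).
    v = cache[id]
    return [tensor.to(v["devices"][i]) for i, tensor in enumerate(v["state"])]

# Usage: a state whose layers live on different devices survives the cache.
state = [torch.zeros(2), torch.ones(2)]  # add .to("cuda") on a GPU machine
save_state(0, state)
restored = load_state(0)
print([t.device for t in restored])  # devices match the originals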