better state cache
This commit is contained in:
parent
75244fbd8b
commit
d9e25ad69f
@ -49,13 +49,26 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
|||||||
if body.model == "":
|
if body.model == "":
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
if "->" in body.strategy:
|
STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
|
||||||
state_cache.disable_state_cache()
|
if not re.match(STRATEGY_REGEX, body.strategy):
|
||||||
else:
|
raise HTTPException(
|
||||||
try:
|
Status.HTTP_400_BAD_REQUEST,
|
||||||
state_cache.enable_state_cache()
|
"Invalid strategy. Please read https://pypi.org/project/rwkv/",
|
||||||
except HTTPException:
|
)
|
||||||
pass
|
devices = set(
|
||||||
|
[
|
||||||
|
x.strip().split(" ")[0].replace("cuda:0", "cuda")
|
||||||
|
for x in body.strategy.split("->")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(f"Devices: {devices}")
|
||||||
|
# if len(devices) > 1:
|
||||||
|
# state_cache.disable_state_cache()
|
||||||
|
# else:
|
||||||
|
try:
|
||||||
|
state_cache.enable_state_cache()
|
||||||
|
except HTTPException:
|
||||||
|
pass
|
||||||
|
|
||||||
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
|
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
|
||||||
|
|
||||||
|
@ -44,6 +44,7 @@ def disable_state_cache():
|
|||||||
dtrie = {}
|
dtrie = {}
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
print("state cache disabled")
|
||||||
return "success"
|
return "success"
|
||||||
|
|
||||||
|
|
||||||
@ -61,8 +62,10 @@ def enable_state_cache():
|
|||||||
dtrie = {}
|
dtrie = {}
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
print("state cache enabled")
|
||||||
return "success"
|
return "success"
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
|
print("state cache disabled")
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "cyac not found")
|
||||||
|
|
||||||
|
|
||||||
@ -87,14 +90,12 @@ def add_state(body: AddStateBody):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
id: int = trie.insert(body.prompt)
|
id: int = trie.insert(body.prompt)
|
||||||
device: torch.device = body.state[0].device
|
devices: List[torch.device] = [tensor.device for tensor in body.state]
|
||||||
dtrie[id] = {
|
dtrie[id] = {
|
||||||
"tokens": copy.deepcopy(body.tokens),
|
"tokens": copy.deepcopy(body.tokens),
|
||||||
"state": [tensor.cpu() for tensor in body.state]
|
"state": [tensor.cpu() for tensor in body.state],
|
||||||
if device != torch.device("cpu")
|
|
||||||
else copy.deepcopy(body.state),
|
|
||||||
"logits": copy.deepcopy(body.logits),
|
"logits": copy.deepcopy(body.logits),
|
||||||
"device": device,
|
"devices": devices,
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(trie) >= max_trie_len:
|
if len(trie) >= max_trie_len:
|
||||||
@ -177,27 +178,18 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
|||||||
pass
|
pass
|
||||||
if id != -1:
|
if id != -1:
|
||||||
v = dtrie[id]
|
v = dtrie[id]
|
||||||
device: torch.device = v["device"]
|
devices: List[torch.device] = v["devices"]
|
||||||
prompt: str = trie[id]
|
prompt: str = trie[id]
|
||||||
|
|
||||||
quick_log(request, body, "Hit:\n" + prompt)
|
quick_log(request, body, "Hit:\n" + prompt)
|
||||||
return {
|
return {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"tokens": v["tokens"],
|
"tokens": v["tokens"],
|
||||||
"state": [tensor.to(device) for tensor in v["state"]]
|
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])],
|
||||||
if device != torch.device("cpu")
|
|
||||||
else v["state"],
|
|
||||||
"logits": v["logits"],
|
"logits": v["logits"],
|
||||||
"device": device.type,
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return {
|
return {"prompt": "", "tokens": [], "state": None, "logits": None}
|
||||||
"prompt": "",
|
|
||||||
"tokens": [],
|
|
||||||
"state": None,
|
|
||||||
"logits": None,
|
|
||||||
"device": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# @router.post("/save-state", tags=["State Cache"])
|
# @router.post("/save-state", tags=["State Cache"])
|
||||||
|
Loading…
Reference in New Issue
Block a user