better state cache

This commit is contained in:
josc146
2023-12-08 15:28:33 +08:00
parent 75244fbd8b
commit d9e25ad69f
2 changed files with 29 additions and 24 deletions

View File

@@ -49,13 +49,26 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
if body.model == "":
return "success"
if "->" in body.strategy:
state_cache.disable_state_cache()
else:
try:
state_cache.enable_state_cache()
except HTTPException:
pass
STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
if not re.match(STRATEGY_REGEX, body.strategy):
raise HTTPException(
Status.HTTP_400_BAD_REQUEST,
"Invalid strategy. Please read https://pypi.org/project/rwkv/",
)
devices = set(
[
x.strip().split(" ")[0].replace("cuda:0", "cuda")
for x in body.strategy.split("->")
]
)
print(f"Devices: {devices}")
# if len(devices) > 1:
# state_cache.disable_state_cache()
# else:
try:
state_cache.enable_state_cache()
except HTTPException:
pass
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"