better state cache
This commit is contained in:
@@ -49,13 +49,26 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
||||
if body.model == "":
|
||||
return "success"
|
||||
|
||||
if "->" in body.strategy:
|
||||
state_cache.disable_state_cache()
|
||||
else:
|
||||
try:
|
||||
state_cache.enable_state_cache()
|
||||
except HTTPException:
|
||||
pass
|
||||
STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
|
||||
if not re.match(STRATEGY_REGEX, body.strategy):
|
||||
raise HTTPException(
|
||||
Status.HTTP_400_BAD_REQUEST,
|
||||
"Invalid strategy. Please read https://pypi.org/project/rwkv/",
|
||||
)
|
||||
devices = set(
|
||||
[
|
||||
x.strip().split(" ")[0].replace("cuda:0", "cuda")
|
||||
for x in body.strategy.split("->")
|
||||
]
|
||||
)
|
||||
print(f"Devices: {devices}")
|
||||
# if len(devices) > 1:
|
||||
# state_cache.disable_state_cache()
|
||||
# else:
|
||||
try:
|
||||
state_cache.enable_state_cache()
|
||||
except HTTPException:
|
||||
pass
|
||||
|
||||
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user