rwkv.cpp(ggml) support
This commit is contained in:
@@ -49,19 +49,13 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
||||
if body.model == "":
|
||||
return "success"
|
||||
|
||||
STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
|
||||
if not re.match(STRATEGY_REGEX, body.strategy):
|
||||
raise HTTPException(
|
||||
Status.HTTP_400_BAD_REQUEST,
|
||||
"Invalid strategy. Please read https://pypi.org/project/rwkv/",
|
||||
)
|
||||
devices = set(
|
||||
[
|
||||
x.strip().split(" ")[0].replace("cuda:0", "cuda")
|
||||
for x in body.strategy.split("->")
|
||||
]
|
||||
)
|
||||
print(f"Devices: {devices}")
|
||||
print(f"Strategy Devices: {devices}")
|
||||
# if len(devices) > 1:
|
||||
# state_cache.disable_state_cache()
|
||||
# else:
|
||||
|
||||
@@ -90,10 +90,15 @@ def add_state(body: AddStateBody):
|
||||
|
||||
try:
|
||||
id: int = trie.insert(body.prompt)
|
||||
devices: List[torch.device] = [tensor.device for tensor in body.state]
|
||||
devices: List[torch.device] = [
|
||||
(tensor.device if hasattr(tensor, "device") else torch.device("cpu"))
|
||||
for tensor in body.state
|
||||
]
|
||||
dtrie[id] = {
|
||||
"tokens": copy.deepcopy(body.tokens),
|
||||
"state": [tensor.cpu() for tensor in body.state],
|
||||
"state": [tensor.cpu() for tensor in body.state]
|
||||
if hasattr(body.state[0], "device")
|
||||
else copy.deepcopy(body.state),
|
||||
"logits": copy.deepcopy(body.logits),
|
||||
"devices": devices,
|
||||
}
|
||||
@@ -185,7 +190,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"tokens": v["tokens"],
|
||||
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])],
|
||||
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])]
|
||||
if hasattr(v["state"][0], "device")
|
||||
else v["state"],
|
||||
"logits": v["logits"],
|
||||
}
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user