Add rwkv.cpp (ggml) support

This commit is contained in:
josc146
2023-12-12 20:29:55 +08:00
parent 6e29f97881
commit b14fbc29b7
26 changed files with 1234 additions and 102 deletions

View File

@@ -49,19 +49,13 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
if body.model == "":
return "success"
STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
if not re.match(STRATEGY_REGEX, body.strategy):
raise HTTPException(
Status.HTTP_400_BAD_REQUEST,
"Invalid strategy. Please read https://pypi.org/project/rwkv/",
)
devices = set(
[
x.strip().split(" ")[0].replace("cuda:0", "cuda")
for x in body.strategy.split("->")
]
)
print(f"Devices: {devices}")
print(f"Strategy Devices: {devices}")
# if len(devices) > 1:
# state_cache.disable_state_cache()
# else:

View File

@@ -90,10 +90,15 @@ def add_state(body: AddStateBody):
try:
id: int = trie.insert(body.prompt)
devices: List[torch.device] = [tensor.device for tensor in body.state]
devices: List[torch.device] = [
(tensor.device if hasattr(tensor, "device") else torch.device("cpu"))
for tensor in body.state
]
dtrie[id] = {
"tokens": copy.deepcopy(body.tokens),
"state": [tensor.cpu() for tensor in body.state],
"state": [tensor.cpu() for tensor in body.state]
if hasattr(body.state[0], "device")
else copy.deepcopy(body.state),
"logits": copy.deepcopy(body.logits),
"devices": devices,
}
@@ -185,7 +190,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
return {
"prompt": prompt,
"tokens": v["tokens"],
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])],
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])]
if hasattr(v["state"][0], "device")
else v["state"],
"logits": v["logits"],
}
else: