rwkv.cpp(ggml) support

This commit is contained in:
josc146
2023-12-12 20:29:55 +08:00
parent 6e29f97881
commit b14fbc29b7
26 changed files with 1234 additions and 102 deletions

View File

@@ -90,10 +90,15 @@ def add_state(body: AddStateBody):
try:
id: int = trie.insert(body.prompt)
devices: List[torch.device] = [tensor.device for tensor in body.state]
devices: List[torch.device] = [
(tensor.device if hasattr(tensor, "device") else torch.device("cpu"))
for tensor in body.state
]
dtrie[id] = {
"tokens": copy.deepcopy(body.tokens),
"state": [tensor.cpu() for tensor in body.state],
"state": [tensor.cpu() for tensor in body.state]
if hasattr(body.state[0], "device")
else copy.deepcopy(body.state),
"logits": copy.deepcopy(body.logits),
"devices": devices,
}
@@ -185,7 +190,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
return {
"prompt": prompt,
"tokens": v["tokens"],
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])],
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])]
if hasattr(v["state"][0], "device")
else v["state"],
"logits": v["logits"],
}
else: