From b41a2e70395958bd400a0fc4bdf21415952c0521 Mon Sep 17 00:00:00 2001 From: josc146 Date: Fri, 2 Jun 2023 21:33:57 +0800 Subject: [PATCH] move state cache to memory (todo: state cache db) --- README.md | 1 + README_ZH.md | 1 + backend-python/routes/state_cache.py | 13 +++++++++++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b67e78a..6dbabf4 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,7 @@ body.json: - [x] CUDA operator int8 acceleration - [ ] macOS support - [ ] Linux support +- [ ] Local State Cache DB ## Related Repositories: diff --git a/README_ZH.md b/README_ZH.md index 589bdb5..ea51beb 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -73,6 +73,7 @@ body.json: - [x] CUDA算子int8提速 - [ ] macOS支持 - [ ] linux支持 +- [ ] 本地状态缓存数据库 ## 相关仓库: diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py index 16e5404..deb1047 100644 --- a/backend-python/routes/state_cache.py +++ b/backend-python/routes/state_cache.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, HTTPException, Response, status from pydantic import BaseModel import gc import copy +import torch router = APIRouter() @@ -41,10 +42,14 @@ def add_state(body: AddStateBody): raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") id = trie.insert(body.prompt) + device = body.state[0].device dtrie[id] = { "tokens": copy.deepcopy(body.tokens), - "state": copy.deepcopy(body.state), + "state": [tensor.cpu() for tensor in body.state] + if device != torch.device("cpu") + else copy.deepcopy(body.state), "logits": copy.deepcopy(body.logits), + "device": device, } return "success" @@ -77,11 +82,15 @@ def longest_prefix_state(body: LongestPrefixStateBody): pass if id != -1: v = dtrie[id] + device = v["device"] return { "prompt": trie[id], "tokens": v["tokens"], - "state": v["state"], + "state": [tensor.to(device) for tensor in v["state"]] + if device != torch.device("cpu") + else v["state"], "logits": v["logits"], + "device": device, } else: return {"prompt": "", "tokens": [], "state": None, "logits": None}