From f56748a94133ea7fc7b9037a432b193696aa4981 Mon Sep 17 00:00:00 2001 From: josc146 Date: Tue, 25 Jul 2023 16:14:29 +0800 Subject: [PATCH] improve python backend startup speed --- backend-python/routes/state_cache.py | 6 ++++-- backend-python/utils/rwkv.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py index ec6d81b..8c646b1 100644 --- a/backend-python/routes/state_cache.py +++ b/backend-python/routes/state_cache.py @@ -4,8 +4,6 @@ from fastapi import APIRouter, HTTPException, Request, Response, status from pydantic import BaseModel import gc import copy -import sys -import torch router = APIRouter() @@ -73,6 +71,8 @@ def add_state(body: AddStateBody): if trie is None: raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") + import torch + try: id: int = trie.insert(body.prompt) device: torch.device = body.state[0].device @@ -147,6 +147,8 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): if trie is None: raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") + import torch + id = -1 try: for id, len in trie.prefix(body.prompt): diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index ca54e1c..83a4c0b 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -7,9 +7,7 @@ from typing import Dict, Iterable, List, Tuple from utils.log import quick_log from fastapi import HTTPException from pydantic import BaseModel, Field -import torch import numpy as np -from rwkv_pip.utils import PIPELINE from routes import state_cache @@ -23,6 +21,7 @@ os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.res class AbstractRWKV(ABC): def __init__(self, model: str, strategy: str, tokens_path: str): from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work + from rwkv_pip.utils import PIPELINE filename, _ = os.path.splitext(os.path.basename(model)) self.name = filename @@ -75,6 +74,8 @@ class AbstractRWKV(ABC): return embedding, token_len def __fast_embedding(self, tokens: List[str], state): + import torch + tokens = [int(x) for x in tokens] token_len = len(tokens) self = self.model