improve python backend startup speed

This commit is contained in:
josc146 2023-07-25 16:14:29 +08:00
parent 29c5b1d804
commit f56748a941
2 changed files with 7 additions and 4 deletions

View File

@ -4,8 +4,6 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
from pydantic import BaseModel
import gc
import copy
import sys
import torch
router = APIRouter()
@ -73,6 +71,8 @@ def add_state(body: AddStateBody):
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
import torch
try:
id: int = trie.insert(body.prompt)
device: torch.device = body.state[0].device
@ -147,6 +147,8 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
import torch
id = -1
try:
for id, len in trie.prefix(body.prompt):

View File

@ -7,9 +7,7 @@ from typing import Dict, Iterable, List, Tuple
from utils.log import quick_log
from fastapi import HTTPException
from pydantic import BaseModel, Field
import torch
import numpy as np
from rwkv_pip.utils import PIPELINE
from routes import state_cache
@ -23,6 +21,7 @@ os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.res
class AbstractRWKV(ABC):
def __init__(self, model: str, strategy: str, tokens_path: str):
from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work
from rwkv_pip.utils import PIPELINE
filename, _ = os.path.splitext(os.path.basename(model))
self.name = filename
@ -75,6 +74,8 @@ class AbstractRWKV(ABC):
return embedding, token_len
def __fast_embedding(self, tokens: List[str], state):
import torch
tokens = [int(x) for x in tokens]
token_len = len(tokens)
self = self.model