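Improve Python backend startup speed by deferring the heavy `torch` and `rwkv_pip` imports from module level to the functions that use them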

This commit is contained in:
josc146 2023-07-25 16:14:29 +08:00
parent 29c5b1d804
commit f56748a941
2 changed files with 7 additions and 4 deletions

View File

@ -4,8 +4,6 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
from pydantic import BaseModel from pydantic import BaseModel
import gc import gc
import copy import copy
import sys
import torch
router = APIRouter() router = APIRouter()
@ -73,6 +71,8 @@ def add_state(body: AddStateBody):
if trie is None: if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
import torch
try: try:
id: int = trie.insert(body.prompt) id: int = trie.insert(body.prompt)
device: torch.device = body.state[0].device device: torch.device = body.state[0].device
@ -147,6 +147,8 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
if trie is None: if trie is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
import torch
id = -1 id = -1
try: try:
for id, len in trie.prefix(body.prompt): for id, len in trie.prefix(body.prompt):

View File

@ -7,9 +7,7 @@ from typing import Dict, Iterable, List, Tuple
from utils.log import quick_log from utils.log import quick_log
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import torch
import numpy as np import numpy as np
from rwkv_pip.utils import PIPELINE
from routes import state_cache from routes import state_cache
@ -23,6 +21,7 @@ os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.res
class AbstractRWKV(ABC): class AbstractRWKV(ABC):
def __init__(self, model: str, strategy: str, tokens_path: str): def __init__(self, model: str, strategy: str, tokens_path: str):
from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work
from rwkv_pip.utils import PIPELINE
filename, _ = os.path.splitext(os.path.basename(model)) filename, _ = os.path.splitext(os.path.basename(model))
self.name = filename self.name = filename
@ -75,6 +74,8 @@ class AbstractRWKV(ABC):
return embedding, token_len return embedding, token_len
def __fast_embedding(self, tokens: List[str], state): def __fast_embedding(self, tokens: List[str], state):
import torch
tokens = [int(x) for x in tokens] tokens = [int(x) for x in tokens]
token_len = len(tokens) token_len = len(tokens)
self = self.model self = self.model