Improve Python backend startup speed
This commit is contained in:
parent
29c5b1d804
commit
f56748a941
@ -4,8 +4,6 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import gc
|
import gc
|
||||||
import copy
|
import copy
|
||||||
import sys
|
|
||||||
import torch
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -73,6 +71,8 @@ def add_state(body: AddStateBody):
|
|||||||
if trie is None:
|
if trie is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
id: int = trie.insert(body.prompt)
|
id: int = trie.insert(body.prompt)
|
||||||
device: torch.device = body.state[0].device
|
device: torch.device = body.state[0].device
|
||||||
@ -147,6 +147,8 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
|||||||
if trie is None:
|
if trie is None:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
id = -1
|
id = -1
|
||||||
try:
|
try:
|
||||||
for id, len in trie.prefix(body.prompt):
|
for id, len in trie.prefix(body.prompt):
|
||||||
|
@ -7,9 +7,7 @@ from typing import Dict, Iterable, List, Tuple
|
|||||||
from utils.log import quick_log
|
from utils.log import quick_log
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from rwkv_pip.utils import PIPELINE
|
|
||||||
from routes import state_cache
|
from routes import state_cache
|
||||||
|
|
||||||
|
|
||||||
@ -23,6 +21,7 @@ os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.res
|
|||||||
class AbstractRWKV(ABC):
|
class AbstractRWKV(ABC):
|
||||||
def __init__(self, model: str, strategy: str, tokens_path: str):
|
def __init__(self, model: str, strategy: str, tokens_path: str):
|
||||||
from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work
|
from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work
|
||||||
|
from rwkv_pip.utils import PIPELINE
|
||||||
|
|
||||||
filename, _ = os.path.splitext(os.path.basename(model))
|
filename, _ = os.path.splitext(os.path.basename(model))
|
||||||
self.name = filename
|
self.name = filename
|
||||||
@ -75,6 +74,8 @@ class AbstractRWKV(ABC):
|
|||||||
return embedding, token_len
|
return embedding, token_len
|
||||||
|
|
||||||
def __fast_embedding(self, tokens: List[str], state):
|
def __fast_embedding(self, tokens: List[str], state):
|
||||||
|
import torch
|
||||||
|
|
||||||
tokens = [int(x) for x in tokens]
|
tokens = [int(x) for x in tokens]
|
||||||
token_len = len(tokens)
|
token_len = len(tokens)
|
||||||
self = self.model
|
self = self.model
|
||||||
|
Loading…
Reference in New Issue
Block a user