Improve Python backend startup speed by deferring heavy imports (e.g. torch) into the functions that use them
This commit is contained in:
		
							parent
							
								
									29c5b1d804
								
							
						
					
					
						commit
						f56748a941
					
				@ -4,8 +4,6 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
import gc
 | 
			
		||||
import copy
 | 
			
		||||
import sys
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
@ -73,6 +71,8 @@ def add_state(body: AddStateBody):
 | 
			
		||||
    if trie is None:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
			
		||||
 | 
			
		||||
    import torch
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        id: int = trie.insert(body.prompt)
 | 
			
		||||
        device: torch.device = body.state[0].device
 | 
			
		||||
@ -147,6 +147,8 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
 | 
			
		||||
    if trie is None:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
			
		||||
 | 
			
		||||
    import torch
 | 
			
		||||
 | 
			
		||||
    id = -1
 | 
			
		||||
    try:
 | 
			
		||||
        for id, len in trie.prefix(body.prompt):
 | 
			
		||||
 | 
			
		||||
@ -7,9 +7,7 @@ from typing import Dict, Iterable, List, Tuple
 | 
			
		||||
from utils.log import quick_log
 | 
			
		||||
from fastapi import HTTPException
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
import torch
 | 
			
		||||
import numpy as np
 | 
			
		||||
from rwkv_pip.utils import PIPELINE
 | 
			
		||||
from routes import state_cache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -23,6 +21,7 @@ os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.res
 | 
			
		||||
class AbstractRWKV(ABC):
 | 
			
		||||
    def __init__(self, model: str, strategy: str, tokens_path: str):
 | 
			
		||||
        from rwkv.model import RWKV as Model  # dynamic import to make RWKV_CUDA_ON work
 | 
			
		||||
        from rwkv_pip.utils import PIPELINE
 | 
			
		||||
 | 
			
		||||
        filename, _ = os.path.splitext(os.path.basename(model))
 | 
			
		||||
        self.name = filename
 | 
			
		||||
@ -75,6 +74,8 @@ class AbstractRWKV(ABC):
 | 
			
		||||
        return embedding, token_len
 | 
			
		||||
 | 
			
		||||
    def __fast_embedding(self, tokens: List[str], state):
 | 
			
		||||
        import torch
 | 
			
		||||
 | 
			
		||||
        tokens = [int(x) for x in tokens]
 | 
			
		||||
        token_len = len(tokens)
 | 
			
		||||
        self = self.model
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user