max_trie_len
This commit is contained in:
		
							parent
							
								
									8431b5d24f
								
							
						
					
					
						commit
						5896593951
					
				@ -4,12 +4,16 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
import gc
 | 
			
		||||
import copy
 | 
			
		||||
import sys
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
trie = None
 | 
			
		||||
dtrie: Dict = {}
 | 
			
		||||
max_trie_len = 3000
 | 
			
		||||
loop_start_id = 1  # to prevent preloaded prompts from being deleted
 | 
			
		||||
loop_del_trie_id = loop_start_id
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init():
 | 
			
		||||
@ -39,10 +43,18 @@ class AddStateBody(BaseModel):
 | 
			
		||||
 | 
			
		||||
@router.post("/add-state")
 | 
			
		||||
def add_state(body: AddStateBody):
 | 
			
		||||
    global trie, dtrie
 | 
			
		||||
    global trie, dtrie, loop_del_trie_id
 | 
			
		||||
    if trie is None:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
			
		||||
 | 
			
		||||
    if len(trie) >= max_trie_len:
 | 
			
		||||
        del_prompt = trie[loop_del_trie_id]
 | 
			
		||||
        trie.remove(del_prompt)
 | 
			
		||||
        dtrie[loop_del_trie_id] = None
 | 
			
		||||
        loop_del_trie_id = loop_del_trie_id + 1
 | 
			
		||||
        if loop_del_trie_id >= max_trie_len:
 | 
			
		||||
            loop_del_trie_id = loop_start_id
 | 
			
		||||
 | 
			
		||||
    id = trie.insert(body.prompt)
 | 
			
		||||
    device = body.state[0].device
 | 
			
		||||
    dtrie[id] = {
 | 
			
		||||
@ -54,6 +66,12 @@ def add_state(body: AddStateBody):
 | 
			
		||||
        "device": device,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    quick_log(
 | 
			
		||||
        None,
 | 
			
		||||
        None,
 | 
			
		||||
        f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    return "success"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -74,6 +92,23 @@ class LongestPrefixStateBody(BaseModel):
 | 
			
		||||
    prompt: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_a_dtrie_buff_size(dtrie_v):
 | 
			
		||||
    # print(sys.getsizeof(dtrie_v["tokens"][0]))  # str
 | 
			
		||||
    # print(sys.getsizeof(dtrie_v["tokens"][0]) * len(dtrie_v["tokens"]))
 | 
			
		||||
    # print(dtrie_v["state"][0][0].element_size())
 | 
			
		||||
    # print(dtrie_v["state"][0].nelement())
 | 
			
		||||
    # print(len(dtrie_v["state"]))
 | 
			
		||||
    # print(
 | 
			
		||||
    #     len(dtrie_v["state"])
 | 
			
		||||
    #     * dtrie_v["state"][0].nelement()
 | 
			
		||||
    #     * dtrie_v["state"][0][0].element_size()
 | 
			
		||||
    # )
 | 
			
		||||
    # print(dtrie_v["logits"][0].element_size())
 | 
			
		||||
    # print(dtrie_v["logits"].nelement())
 | 
			
		||||
    # print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement())
 | 
			
		||||
    return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/longest-prefix-state")
 | 
			
		||||
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
 | 
			
		||||
    global trie
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user