max_trie_len
This commit is contained in:
		
							parent
							
								
									8431b5d24f
								
							
						
					
					
						commit
						5896593951
					
				@ -4,12 +4,16 @@ from fastapi import APIRouter, HTTPException, Request, Response, status
 | 
				
			|||||||
from pydantic import BaseModel
 | 
					from pydantic import BaseModel
 | 
				
			||||||
import gc
 | 
					import gc
 | 
				
			||||||
import copy
 | 
					import copy
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
router = APIRouter()
 | 
					router = APIRouter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
trie = None
 | 
					trie = None
 | 
				
			||||||
dtrie: Dict = {}
 | 
					dtrie: Dict = {}
 | 
				
			||||||
 | 
					max_trie_len = 3000
 | 
				
			||||||
 | 
					loop_start_id = 1  # to prevent preloaded prompts from being deleted
 | 
				
			||||||
 | 
					loop_del_trie_id = loop_start_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def init():
 | 
					def init():
 | 
				
			||||||
@ -39,10 +43,18 @@ class AddStateBody(BaseModel):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
@router.post("/add-state")
 | 
					@router.post("/add-state")
 | 
				
			||||||
def add_state(body: AddStateBody):
 | 
					def add_state(body: AddStateBody):
 | 
				
			||||||
    global trie, dtrie
 | 
					    global trie, dtrie, loop_del_trie_id
 | 
				
			||||||
    if trie is None:
 | 
					    if trie is None:
 | 
				
			||||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
					        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if len(trie) >= max_trie_len:
 | 
				
			||||||
 | 
					        del_prompt = trie[loop_del_trie_id]
 | 
				
			||||||
 | 
					        trie.remove(del_prompt)
 | 
				
			||||||
 | 
					        dtrie[loop_del_trie_id] = None
 | 
				
			||||||
 | 
					        loop_del_trie_id = loop_del_trie_id + 1
 | 
				
			||||||
 | 
					        if loop_del_trie_id >= max_trie_len:
 | 
				
			||||||
 | 
					            loop_del_trie_id = loop_start_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    id = trie.insert(body.prompt)
 | 
					    id = trie.insert(body.prompt)
 | 
				
			||||||
    device = body.state[0].device
 | 
					    device = body.state[0].device
 | 
				
			||||||
    dtrie[id] = {
 | 
					    dtrie[id] = {
 | 
				
			||||||
@ -54,6 +66,12 @@ def add_state(body: AddStateBody):
 | 
				
			|||||||
        "device": device,
 | 
					        "device": device,
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    quick_log(
 | 
				
			||||||
 | 
					        None,
 | 
				
			||||||
 | 
					        None,
 | 
				
			||||||
 | 
					        f"New Trie Id: {id}\nTrie Len: {len(trie)}\nTrie Buff Size: {trie.buff_size()}\nDtrie Buff Size Of Id: {_get_a_dtrie_buff_size(dtrie[id])}",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return "success"
 | 
					    return "success"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -74,6 +92,23 @@ class LongestPrefixStateBody(BaseModel):
 | 
				
			|||||||
    prompt: str
 | 
					    prompt: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _get_a_dtrie_buff_size(dtrie_v):
 | 
				
			||||||
 | 
					    # print(sys.getsizeof(dtrie_v["tokens"][0]))  # str
 | 
				
			||||||
 | 
					    # print(sys.getsizeof(dtrie_v["tokens"][0]) * len(dtrie_v["tokens"]))
 | 
				
			||||||
 | 
					    # print(dtrie_v["state"][0][0].element_size())
 | 
				
			||||||
 | 
					    # print(dtrie_v["state"][0].nelement())
 | 
				
			||||||
 | 
					    # print(len(dtrie_v["state"]))
 | 
				
			||||||
 | 
					    # print(
 | 
				
			||||||
 | 
					    #     len(dtrie_v["state"])
 | 
				
			||||||
 | 
					    #     * dtrie_v["state"][0].nelement()
 | 
				
			||||||
 | 
					    #     * dtrie_v["state"][0][0].element_size()
 | 
				
			||||||
 | 
					    # )
 | 
				
			||||||
 | 
					    # print(dtrie_v["logits"][0].element_size())
 | 
				
			||||||
 | 
					    # print(dtrie_v["logits"].nelement())
 | 
				
			||||||
 | 
					    # print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement())
 | 
				
			||||||
 | 
					    return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@router.post("/longest-prefix-state")
 | 
					@router.post("/longest-prefix-state")
 | 
				
			||||||
def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
 | 
					def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
 | 
				
			||||||
    global trie
 | 
					    global trie
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user