type
This commit is contained in:
		
							parent
							
								
									d32351c130
								
							
						
					
					
						commit
						377f71b16b
					
				| @ -48,8 +48,8 @@ def add_state(body: AddStateBody): | |||||||
|         raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") |         raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") | ||||||
| 
 | 
 | ||||||
|     try: |     try: | ||||||
|         id = trie.insert(body.prompt) |         id: int = trie.insert(body.prompt) | ||||||
|         device = body.state[0].device |         device: torch.device = body.state[0].device | ||||||
|         dtrie[id] = { |         dtrie[id] = { | ||||||
|             "tokens": copy.deepcopy(body.tokens), |             "tokens": copy.deepcopy(body.tokens), | ||||||
|             "state": [tensor.cpu() for tensor in body.state] |             "state": [tensor.cpu() for tensor in body.state] | ||||||
| @ -110,7 +110,7 @@ def _get_a_dtrie_buff_size(dtrie_v): | |||||||
|     # print(dtrie_v["logits"][0].element_size()) |     # print(dtrie_v["logits"][0].element_size()) | ||||||
|     # print(dtrie_v["logits"].nelement()) |     # print(dtrie_v["logits"].nelement()) | ||||||
|     # print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement()) |     # print(dtrie_v["logits"][0].element_size() * dtrie_v["logits"].nelement()) | ||||||
|     return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28 |     return 54 * len(dtrie_v["tokens"]) + 491520 + 262144 + 28  # TODO | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @router.post("/longest-prefix-state") | @router.post("/longest-prefix-state") | ||||||
| @ -127,8 +127,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): | |||||||
|         pass |         pass | ||||||
|     if id != -1: |     if id != -1: | ||||||
|         v = dtrie[id] |         v = dtrie[id] | ||||||
|         device = v["device"] |         device: torch.device = v["device"] | ||||||
|         prompt = trie[id] |         prompt: str = trie[id] | ||||||
|  | 
 | ||||||
|         quick_log(request, body, "Hit:\n" + prompt) |         quick_log(request, body, "Hit:\n" + prompt) | ||||||
|         return { |         return { | ||||||
|             "prompt": prompt, |             "prompt": prompt, | ||||||
| @ -137,7 +138,7 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): | |||||||
|             if device != torch.device("cpu") |             if device != torch.device("cpu") | ||||||
|             else v["state"], |             else v["state"], | ||||||
|             "logits": v["logits"], |             "logits": v["logits"], | ||||||
|             "device": device, |             "device": device.type, | ||||||
|         } |         } | ||||||
|     else: |     else: | ||||||
|         return { |         return { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user