improve state cache performance
This commit is contained in:
		
							parent
							
								
									e083f2c629
								
							
						
					
					
						commit
						7f3cfd54b0
					
				| @ -94,28 +94,19 @@ def add_state(body: AddStateBody): | |||||||
|         state: Union[Any, None] = None |         state: Union[Any, None] = None | ||||||
| 
 | 
 | ||||||
|         if body.state is not None: |         if body.state is not None: | ||||||
|             if type(body.state) == list or type(body.state) == np.ndarray: |             if type(state) == list and hasattr(state[0], "device"):  # torch | ||||||
|                 devices = [ |                 devices = [tensor.device for tensor in body.state] | ||||||
|                     ( |                 state = [tensor.cpu() for tensor in body.state] | ||||||
|                         tensor.device |             elif type(body.state) == np.ndarray:  # rwkv.cpp | ||||||
|                         if hasattr(tensor, "device") |                 state = body.state | ||||||
|                         else torch.device("cpu") |             else:  # WebGPU | ||||||
|                     ) |                 state = body.state.back() | ||||||
|                     for tensor in body.state |  | ||||||
|                 ] |  | ||||||
|                 state = ( |  | ||||||
|                     [tensor.cpu() for tensor in body.state] |  | ||||||
|                     if hasattr(body.state[0], "device") |  | ||||||
|                     else copy.deepcopy(body.state) |  | ||||||
|                 ) |  | ||||||
|             else: |  | ||||||
|                 state = body.state.back()  # WebGPU |  | ||||||
| 
 | 
 | ||||||
|         id: int = trie.insert(body.prompt) |         id: int = trie.insert(body.prompt) | ||||||
|         dtrie[id] = { |         dtrie[id] = { | ||||||
|             "tokens": copy.deepcopy(body.tokens), |             "tokens": body.tokens, | ||||||
|             "state": state, |             "state": state, | ||||||
|             "logits": copy.deepcopy(body.logits), |             "logits": body.logits, | ||||||
|             "devices": devices, |             "devices": devices, | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
| @ -199,12 +190,12 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): | |||||||
|     except: |     except: | ||||||
|         pass |         pass | ||||||
|     if id != -1: |     if id != -1: | ||||||
|  |         prompt: str = trie[id] | ||||||
|         v = dtrie[id] |         v = dtrie[id] | ||||||
|         devices: List[torch.device] = v["devices"] |         devices: List[torch.device] = v["devices"] | ||||||
|         prompt: str = trie[id] |  | ||||||
|         state: Union[Any, None] = v["state"] |         state: Union[Any, None] = v["state"] | ||||||
| 
 | 
 | ||||||
|         if state is not None and type(state) == list and hasattr(state[0], "device"): |         if type(state) == list and hasattr(state[0], "device"):  # torch | ||||||
|             state = [tensor.to(devices[i]) for i, tensor in enumerate(state)] |             state = [tensor.to(devices[i]) for i, tensor in enumerate(state)] | ||||||
| 
 | 
 | ||||||
|         quick_log(request, body, "Hit:\n" + prompt) |         quick_log(request, body, "Hit:\n" + prompt) | ||||||
|  | |||||||
| @ -239,14 +239,9 @@ class AbstractRWKV(ABC): | |||||||
|             self.model_tokens = [] |             self.model_tokens = [] | ||||||
|         else: |         else: | ||||||
|             delta_prompt = prompt[len(cache["prompt"]) :] |             delta_prompt = prompt[len(cache["prompt"]) :] | ||||||
|             state = cache["state"] |             self.model_state = cache["state"] | ||||||
|             self.model_state = ( |             self.model_tokens = cache["tokens"] | ||||||
|                 copy.deepcopy(state) |             logits = cache["logits"] | ||||||
|                 if type(state) == list or type(state) == np.ndarray |  | ||||||
|                 else state |  | ||||||
|             ) |  | ||||||
|             self.model_tokens = copy.deepcopy(cache["tokens"]) |  | ||||||
|             logits = copy.deepcopy(cache["logits"]) |  | ||||||
| 
 | 
 | ||||||
|         prompt_token_len = 0 |         prompt_token_len = 0 | ||||||
|         if delta_prompt != "": |         if delta_prompt != "": | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user