webgpu(python) state cache

This commit is contained in:
josc146 2023-12-28 20:43:57 +08:00
parent e33858f110
commit e083f2c629
3 changed files with 13 additions and 3 deletions

View File

@ -109,7 +109,7 @@ def add_state(body: AddStateBody):
else copy.deepcopy(body.state) else copy.deepcopy(body.state)
) )
else: else:
pass # WebGPU state = body.state.back() # WebGPU
id: int = trie.insert(body.prompt) id: int = trie.insert(body.prompt)
dtrie[id] = { dtrie[id] = {

View File

@ -23,4 +23,9 @@ class RWKV:
self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab
def forward(self, tokens: List[int], state: Union[Any, None] = None): def forward(self, tokens: List[int], state: Union[Any, None] = None):
return wrp.v5.run_one(self.model, tokens, state) if type(state).__name__ == "BackedState": # memory state
gpu_state = wrp.v5.ModelState(self.model, 1)
gpu_state.load(state)
else:
gpu_state = state
return wrp.v5.run_one(self.model, tokens, gpu_state)

View File

@ -239,7 +239,12 @@ class AbstractRWKV(ABC):
self.model_tokens = [] self.model_tokens = []
else: else:
delta_prompt = prompt[len(cache["prompt"]) :] delta_prompt = prompt[len(cache["prompt"]) :]
self.model_state = copy.deepcopy(cache["state"]) state = cache["state"]
self.model_state = (
copy.deepcopy(state)
if type(state) == list or type(state) == np.ndarray
else state
)
self.model_tokens = copy.deepcopy(cache["tokens"]) self.model_tokens = copy.deepcopy(cache["tokens"])
logits = copy.deepcopy(cache["logits"]) logits = copy.deepcopy(cache["logits"])