From e083f2c6294995d6872bafd7428b8cc709d1533e Mon Sep 17 00:00:00 2001 From: josc146 Date: Thu, 28 Dec 2023 20:43:57 +0800 Subject: [PATCH] webgpu(python) state cache --- backend-python/routes/state_cache.py | 2 +- backend-python/rwkv_pip/webgpu/model.py | 7 ++++++- backend-python/utils/rwkv.py | 7 ++++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py index 33bf74f..647ffac 100644 --- a/backend-python/routes/state_cache.py +++ b/backend-python/routes/state_cache.py @@ -109,7 +109,7 @@ def add_state(body: AddStateBody): else copy.deepcopy(body.state) ) else: - pass # WebGPU + state = body.state.back() # WebGPU id: int = trie.insert(body.prompt) dtrie[id] = { diff --git a/backend-python/rwkv_pip/webgpu/model.py b/backend-python/rwkv_pip/webgpu/model.py index 5d65344..07df831 100644 --- a/backend-python/rwkv_pip/webgpu/model.py +++ b/backend-python/rwkv_pip/webgpu/model.py @@ -23,4 +23,9 @@ class RWKV: self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab def forward(self, tokens: List[int], state: Union[Any, None] = None): - return wrp.v5.run_one(self.model, tokens, state) + if type(state).__name__ == "BackedState": # memory state + gpu_state = wrp.v5.ModelState(self.model, 1) + gpu_state.load(state) + else: + gpu_state = state + return wrp.v5.run_one(self.model, tokens, gpu_state) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index 8414ed3..9109c47 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -239,7 +239,12 @@ class AbstractRWKV(ABC): self.model_tokens = [] else: delta_prompt = prompt[len(cache["prompt"]) :] - self.model_state = copy.deepcopy(cache["state"]) + state = cache["state"] + self.model_state = ( + copy.deepcopy(state) + if type(state) == list or type(state) == np.ndarray + else state + ) self.model_tokens = copy.deepcopy(cache["tokens"]) logits = copy.deepcopy(cache["logits"])