Commit e083f2c629 (parent e33858f110): webgpu(python) state cache
@ -109,7 +109,7 @@ def add_state(body: AddStateBody):
|
|||||||
else copy.deepcopy(body.state)
|
else copy.deepcopy(body.state)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pass # WebGPU
|
state = body.state.back() # WebGPU
|
||||||
|
|
||||||
id: int = trie.insert(body.prompt)
|
id: int = trie.insert(body.prompt)
|
||||||
dtrie[id] = {
|
dtrie[id] = {
|
||||||
|
7
backend-python/rwkv_pip/webgpu/model.py
vendored
7
backend-python/rwkv_pip/webgpu/model.py
vendored
@ -23,4 +23,9 @@ class RWKV:
|
|||||||
self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab
|
self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab
|
||||||
|
|
||||||
def forward(self, tokens: List[int], state: Union[Any, None] = None):
    """Run one inference step on the webgpu backend.

    Parameters:
        tokens: token ids to feed to the model for this step.
        state: either a GPU-resident model state, a memory-backed
            "BackedState" snapshot, or None (fresh state).

    Returns:
        Whatever ``wrp.v5.run_one`` returns for this model/tokens/state
        triple (logits plus updated state in the wrp API).

    NOTE(review): the BackedState check is done by type *name* rather
    than ``isinstance`` — presumably to avoid importing the wrp class
    directly; confirm before "fixing" it to isinstance.
    """
    if type(state).__name__ == "BackedState":  # memory state
        # Rehydrate the CPU-side snapshot into a fresh GPU state
        # (second argument 1 is the batch size — TODO confirm).
        gpu_state = wrp.v5.ModelState(self.model, 1)
        gpu_state.load(state)
    else:
        # Already a GPU state (or None for a fresh run); pass through.
        gpu_state = state
    return wrp.v5.run_one(self.model, tokens, gpu_state)
|
||||||
|
@ -239,7 +239,12 @@ class AbstractRWKV(ABC):
|
|||||||
self.model_tokens = []
|
self.model_tokens = []
|
||||||
else:
|
else:
|
||||||
delta_prompt = prompt[len(cache["prompt"]) :]
|
delta_prompt = prompt[len(cache["prompt"]) :]
|
||||||
self.model_state = copy.deepcopy(cache["state"])
|
state = cache["state"]
|
||||||
|
self.model_state = (
|
||||||
|
copy.deepcopy(state)
|
||||||
|
if type(state) == list or type(state) == np.ndarray
|
||||||
|
else state
|
||||||
|
)
|
||||||
self.model_tokens = copy.deepcopy(cache["tokens"])
|
self.model_tokens = copy.deepcopy(cache["tokens"])
|
||||||
logits = copy.deepcopy(cache["logits"])
|
logits = copy.deepcopy(cache["logits"])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user