webgpu(python) state cache
This commit is contained in:
7
backend-python/rwkv_pip/webgpu/model.py
vendored
7
backend-python/rwkv_pip/webgpu/model.py
vendored
@@ -23,4 +23,9 @@ class RWKV:
|
||||
self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab
|
||||
|
||||
def forward(self, tokens: List[int], state: Union[Any, None] = None):
|
||||
return wrp.v5.run_one(self.model, tokens, state)
|
||||
if type(state).__name__ == "BackedState": # memory state
|
||||
gpu_state = wrp.v5.ModelState(self.model, 1)
|
||||
gpu_state.load(state)
|
||||
else:
|
||||
gpu_state = state
|
||||
return wrp.v5.run_one(self.model, tokens, gpu_state)
|
||||
|
||||
Reference in New Issue
Block a user