diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py index 810676d..cae32ad 100644 --- a/backend-python/routes/state_cache.py +++ b/backend-python/routes/state_cache.py @@ -96,7 +96,9 @@ def copy_tensor_to_cpu(tensors): elif tensors_type == np.ndarray: # rwkv.cpp copied = tensors else: # WebGPU state - copied = tensors.back() + model = global_var.get(global_var.Model) + # NOTE(review): model may be None here (e.g. mid model-switch); default copied to None to avoid NameError at return + copied = model.model.model.back_state() if model else None return copied, devices @@ -238,11 +240,14 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): state: Union[Any, None] = v["state"] logits: Union[Any, None] = v["logits"] - if type(state) == list and hasattr(state[0], "device"): # torch + state_type = type(state) + if state_type == list and hasattr(state[0], "device"): # torch state = [ - tensor.to(devices[i]) - if devices[i] != torch.device("cpu") - else tensor.clone() + ( + tensor.to(devices[i]) + if devices[i] != torch.device("cpu") + else tensor.clone() + ) for i, tensor in enumerate(state) ] logits = ( @@ -250,7 +255,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): if logits_device != torch.device("cpu") else logits.clone() ) - else: # rwkv.cpp, WebGPU + elif state_type == np.ndarray: # rwkv.cpp + logits = np.copy(logits) + else: # WebGPU logits = np.copy(logits) quick_log(request, body, "Hit:\n" + prompt) diff --git a/backend-python/rwkv_pip/webgpu/model.py b/backend-python/rwkv_pip/webgpu/model.py index 950e49a..7c4c085 100644 --- a/backend-python/rwkv_pip/webgpu/model.py +++ b/backend-python/rwkv_pip/webgpu/model.py @@ -13,13 +13,6 @@ except ModuleNotFoundError: class RWKV: def __init__(self, model_path: str, strategy: str = None): - self.info = wrp.peek_info(model_path) - self.w = {} # fake weight - self.w["emb.weight"] = [0] * self.info.num_vocab - self.version = str(self.info.version).lower() - self.wrp = getattr(wrp, self.version) - self.version = float(self.version.replace("v", "")) - layer = ( 
int(s.lstrip("layer")) for s in strategy.split() @@ -33,21 +26,25 @@ class RWKV: for s in s.split(",") if s.startswith("chunk") ) + self.token_chunk_size = next(chunk_size, 32) args = { - "file": model_path, - "turbo": True, + "path": model_path, "quant": next(layer, 31) if "i8" in strategy else 0, "quant_nf4": next(layer, 26) if "i4" in strategy else 0, - "token_chunk_size": next(chunk_size, 32), - "lora": None, } - self.model = self.wrp.Model(**args) + self.model = wrp.Model(**args) + self.info = self.model.info() + self.w = {} # fake weight + self.w["emb.weight"] = [0] * self.info.num_vocab + self.version = str(self.info.version).lower() + self.version = float(self.version.replace("v", "")) def forward(self, tokens: List[int], state: Union[Any, None] = None): - if type(state).__name__ == "BackedState": # memory state - gpu_state = self.wrp.ModelState(self.model, 1) - gpu_state.load(state) - else: - gpu_state = state - return self.wrp.run_one(self.model, tokens, gpu_state) + if state is None: + self.model.clear_state() + elif type(state).__name__ == "State_Cpu": + self.model.load_state(state) + logits = self.model.run(tokens, self.token_chunk_size) + ret_state = "State_Gpu" + return logits, ret_state diff --git a/backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd b/backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd index 50699c6..481dcab 100644 Binary files a/backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd and b/backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd differ