From 18d4b2304e90f47e9b19bcf26552c838caac4158 Mon Sep 17 00:00:00 2001 From: josc146 Date: Thu, 14 Dec 2023 20:39:42 +0800 Subject: [PATCH] WebGPU (Python) strategy --- backend-python/rwkv_pip/webgpu/model.py | 9 +++++++-- frontend/src/pages/Configs.tsx | 6 +++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/backend-python/rwkv_pip/webgpu/model.py b/backend-python/rwkv_pip/webgpu/model.py index 46529bb..55f8321 100644 --- a/backend-python/rwkv_pip/webgpu/model.py +++ b/backend-python/rwkv_pip/webgpu/model.py @@ -12,8 +12,13 @@ except ModuleNotFoundError: class RWKV: - def __init__(self, model_path: str, strategy=None): - self.model = wrp.v5.Model(model_path, turbo=False) + def __init__(self, model_path: str, strategy: str = None): + self.model = wrp.v5.Model( + model_path, + turbo=False, + quant=32 if "i8" in strategy else None, + quant_nf4=26 if "i4" in strategy else None, + ) self.w = {} # fake weight self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab diff --git a/frontend/src/pages/Configs.tsx b/frontend/src/pages/Configs.tsx index ea633d9..f0a4335 100644 --- a/frontend/src/pages/Configs.tsx +++ b/frontend/src/pages/Configs.tsx @@ -282,7 +282,7 @@ const Configs: FC = observer(() => { selectedConfig.modelParameters.device !== 'Custom' && { {selectedConfig.modelParameters.device !== 'CPU' && selectedConfig.modelParameters.device !== 'MPS' && } {selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && } - {selectedConfig.modelParameters.device === 'WebGPU' && } - {selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && selectedConfig.modelParameters.device !== 'WebGPU' && + {selectedConfig.modelParameters.device.startsWith('WebGPU') && } + {selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && !selectedConfig.modelParameters.device.startsWith('WebGPU') && } {selectedConfig.modelParameters.device === 'CPU (rwkv.cpp)' && }