From 0ddd2e9fea506d13a8f4a3e1e92856440c918c9f Mon Sep 17 00:00:00 2001 From: josc146 Date: Thu, 14 Dec 2023 18:37:07 +0800 Subject: [PATCH] add WebGPU Python Mode (https://github.com/cryscan/web-rwkv-py) --- .github/workflows/release.yml | 2 ++ backend-golang/rwkv.go | 16 ++++++++- backend-python/convert_safetensors.py | 27 ++++++++++++++ backend-python/main.py | 5 +++ backend-python/routes/completion.py | 3 +- backend-python/routes/state_cache.py | 39 +++++++++++++++------ backend-python/rwkv_pip/utils.py | 2 ++ backend-python/rwkv_pip/webgpu/model.py | 21 +++++++++++ backend-python/utils/rwkv.py | 13 +++++-- frontend/src/components/RunButton.tsx | 13 +++---- frontend/src/pages/Configs.tsx | 16 +++++---- frontend/src/types/configs.ts | 2 +- frontend/src/utils/convert-model.ts | 15 ++++++-- frontend/src/utils/index.tsx | 3 +- frontend/wailsjs/go/backend_golang/App.d.ts | 4 ++- frontend/wailsjs/go/backend_golang/App.js | 8 +++-- 16 files changed, 155 insertions(+), 34 deletions(-) create mode 100644 backend-python/rwkv_pip/webgpu/model.py diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4fd50e9..7cc7f85 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -98,6 +98,7 @@ jobs: rm ./backend-python/get-pip.py rm ./backend-python/rwkv_pip/cpp/librwkv.dylib rm ./backend-python/rwkv_pip/cpp/rwkv.dll + rm ./backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd make mv build/bin/RWKV-Runner build/bin/RWKV-Runner_linux_x64 @@ -124,6 +125,7 @@ jobs: rm ./backend-python/get-pip.py rm ./backend-python/rwkv_pip/cpp/rwkv.dll rm ./backend-python/rwkv_pip/cpp/librwkv.so + rm ./backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd make cp build/darwin/Readme_Install.txt build/bin/Readme_Install.txt cp build/bin/RWKV-Runner.app/Contents/MacOS/RWKV-Runner build/bin/RWKV-Runner_darwin_universal diff --git a/backend-golang/rwkv.go b/backend-golang/rwkv.go index a6c5d8f..e6521c2 100644 --- a/backend-golang/rwkv.go +++ b/backend-golang/rwkv.go @@ -10,7 +10,7 @@ import ( "strings" ) -func (a *App) StartServer(python string, port int, host string, webui bool, rwkvBeta bool, rwkvcpp bool) (string, error) { +func (a *App) StartServer(python string, port int, host string, webui bool, rwkvBeta bool, rwkvcpp bool, webgpu bool) (string, error) { var err error if python == "" { python, err = GetPython() @@ -28,6 +28,9 @@ func (a *App) StartServer(python string, port int, host string, webui bool, rwkv if rwkvcpp { args = append(args, "--rwkv.cpp") } + if webgpu { + args = append(args, "--webgpu") + } args = append(args, "--port", strconv.Itoa(port), "--host", host) return Cmd(args...) } @@ -55,6 +58,17 @@ func (a *App) ConvertSafetensors(modelPath string, outPath string) (string, erro return Cmd(args...) } +func (a *App) ConvertSafetensorsWithPython(python string, modelPath string, outPath string) (string, error) { + var err error + if python == "" { + python, err = GetPython() + } + if err != nil { + return "", err + } + return Cmd(python, "./backend-python/convert_safetensors.py", "--input", modelPath, "--output", outPath) +} + func (a *App) ConvertGGML(python string, modelPath string, outPath string, Q51 bool) (string, error) { var err error if python == "" { diff --git a/backend-python/convert_safetensors.py b/backend-python/convert_safetensors.py index 6637b2e..131ee8f 100644 --- a/backend-python/convert_safetensors.py +++ b/backend-python/convert_safetensors.py @@ -30,6 +30,33 @@ def convert_file(pt_filename: str, sf_filename: str, rename={}, transpose_names= if "state_dict" in loaded: loaded = loaded["state_dict"] + kk = list(loaded.keys()) + version = 4 + for x in kk: + if "ln_x" in x: + version = max(5, version) + if "gate.weight" in x: + version = max(5.1, version) + if int(version) == 5 and "att.time_decay" in x: + if len(loaded[x].shape) > 1: + if loaded[x].shape[1] > 1: + version = max(5.2, version) + if "time_maa" in x: + version = max(6, version) + + if version == 5.1 and "midi" in pt_filename.lower(): + import numpy as np + + np.set_printoptions(precision=4, suppress=True, linewidth=200) + kk = list(loaded.keys()) + _, n_emb = loaded["emb.weight"].shape + for k in kk: + if "time_decay" in k or "time_faaaa" in k: + # print(k, mm[k].shape) + loaded[k] = ( + loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0]) + ) + loaded = {k: v.clone().half() for k, v in loaded.items()} # for k, v in loaded.items(): # print(f'{k}\t{v.shape}\t{v.dtype}') diff --git a/backend-python/main.py b/backend-python/main.py index 1ef38c9..92d6b20 100644 --- a/backend-python/main.py +++ b/backend-python/main.py @@ -37,6 +37,11 @@ def get_args(args: Union[Sequence[str], None] = None): action="store_true", help="whether to use rwkv.cpp (default: False)", ) + group.add_argument( + "--webgpu", + action="store_true", + help="whether to use webgpu (default: False)", + ) args = parser.parse_args(args) return args diff --git a/backend-python/routes/completion.py b/backend-python/routes/completion.py index 59d3250..9449e8f 100644 --- a/backend-python/routes/completion.py +++ b/backend-python/routes/completion.py @@ -8,7 +8,6 @@ import base64 from fastapi import APIRouter, Request, status, HTTPException from sse_starlette.sse import EventSourceResponse from pydantic import BaseModel, Field -import numpy as np import tiktoken from utils.rwkv import * from utils.log import quick_log @@ -396,6 +395,8 @@ class EmbeddingsBody(BaseModel): def embedding_base64(embedding: List[float]) -> str: + import numpy as np + return base64.b64encode(np.array(embedding).astype(np.float32)).decode("utf-8") diff --git a/backend-python/routes/state_cache.py b/backend-python/routes/state_cache.py index 16693c4..33bf74f 100644 --- a/backend-python/routes/state_cache.py +++ b/backend-python/routes/state_cache.py @@ -87,18 +87,34 @@ def add_state(body: AddStateBody): raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") import torch + import numpy as np try: + devices: List[torch.device] = [] + state: Union[Any, None] = None + + if body.state is not None: + if type(body.state) == list or type(body.state) == np.ndarray: + devices = [ + ( + tensor.device + if hasattr(tensor, "device") + else torch.device("cpu") + ) + for tensor in body.state + ] + state = ( + [tensor.cpu() for tensor in body.state] + if hasattr(body.state[0], "device") + else copy.deepcopy(body.state) + ) + else: + pass # WebGPU + id: int = trie.insert(body.prompt) - devices: List[torch.device] = [ - (tensor.device if hasattr(tensor, "device") else torch.device("cpu")) - for tensor in body.state - ] dtrie[id] = { "tokens": copy.deepcopy(body.tokens), - "state": [tensor.cpu() for tensor in body.state] - if hasattr(body.state[0], "device") - else copy.deepcopy(body.state), + "state": state, "logits": copy.deepcopy(body.logits), "devices": devices, } @@ -174,6 +190,7 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded") import torch + import numpy as np id = -1 try: @@ -185,14 +202,16 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request): v = dtrie[id] devices: List[torch.device] = v["devices"] prompt: str = trie[id] + state: Union[Any, None] = v["state"] + + if state is not None and type(state) == list and hasattr(state[0], "device"): + state = [tensor.to(devices[i]) for i, tensor in enumerate(state)] quick_log(request, body, "Hit:\n" + prompt) return { "prompt": prompt, "tokens": v["tokens"], - "state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])] - if hasattr(v["state"][0], "device") - else v["state"], + "state": state, "logits": v["logits"], } else: diff --git a/backend-python/rwkv_pip/utils.py b/backend-python/rwkv_pip/utils.py index a0f8ea4..f492ec2 100644 --- a/backend-python/rwkv_pip/utils.py +++ b/backend-python/rwkv_pip/utils.py @@ -84,6 +84,8 @@ class PIPELINE: return e / e.sum(axis=axis, keepdims=True) def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0): + if type(logits) == list: + logits = np.array(logits) np_logits = type(logits) == np.ndarray if np_logits: probs = self.np_softmax(logits, axis=-1) diff --git a/backend-python/rwkv_pip/webgpu/model.py b/backend-python/rwkv_pip/webgpu/model.py new file mode 100644 index 0000000..46529bb --- /dev/null +++ b/backend-python/rwkv_pip/webgpu/model.py @@ -0,0 +1,21 @@ +from typing import Any, List, Union + +try: + import web_rwkv_py as wrp +except ModuleNotFoundError: + try: + from . import web_rwkv_py as wrp + except ImportError: + raise ModuleNotFoundError( + "web_rwkv_py not found, install it from https://github.com/cryscan/web-rwkv-py" + ) + + +class RWKV: + def __init__(self, model_path: str, strategy=None): + self.model = wrp.v5.Model(model_path, turbo=False) + self.w = {} # fake weight + self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab + + def forward(self, tokens: List[int], state: Union[Any, None] = None): + return wrp.v5.run_one(self.model, tokens, state) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index ed60e81..8414ed3 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -8,7 +8,6 @@ from typing import Dict, Iterable, List, Tuple, Union, Type from utils.log import quick_log from fastapi import HTTPException from pydantic import BaseModel, Field -import numpy as np from routes import state_cache import global_var @@ -68,6 +67,8 @@ class AbstractRWKV(ABC): pass def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]: + import numpy as np + if fast_mode: embedding, token_len = self.__fast_embedding( self.fix_tokens(self.pipeline.encode(input)), None @@ -222,6 +223,8 @@ class AbstractRWKV(ABC): def generate( self, prompt: str, stop: Union[str, List[str], None] = None ) -> Iterable[Tuple[str, str, int, int]]: + import numpy as np + quick_log(None, None, "Generation Prompt:\n" + prompt) cache = None delta_prompt = prompt @@ -231,7 +234,7 @@ class AbstractRWKV(ABC): ) except HTTPException: pass - if cache is None or cache["prompt"] == "": + if cache is None or cache["prompt"] == "" or cache["state"] is None: self.model_state = None self.model_tokens = [] else: @@ -511,6 +514,7 @@ def get_tokenizer(tokenizer_len: int): def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV: rwkv_beta = global_var.get(global_var.Args).rwkv_beta rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp") + webgpu = global_var.get(global_var.Args).webgpu if "midi" in model.lower() or "abc" in model.lower(): os.environ["RWKV_RESCALE_LAYER"] = "999" @@ -526,6 +530,11 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV from rwkv_pip.cpp.model import ( RWKV as Model, ) + elif webgpu: + print("Using webgpu") + from rwkv_pip.webgpu.model import ( + RWKV as Model, + ) else: from rwkv_pip.model import ( RWKV as Model, diff --git a/frontend/src/components/RunButton.tsx b/frontend/src/components/RunButton.tsx index 1020c9a..534a3bc 100644 --- a/frontend/src/components/RunButton.tsx +++ b/frontend/src/components/RunButton.tsx @@ -48,6 +48,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean const modelConfig = commonStore.getCurrentModelConfig(); const webgpu = modelConfig.modelParameters.device === 'WebGPU'; + const webgpuPython = modelConfig.modelParameters.device === 'WebGPU (Python)'; const cpp = modelConfig.modelParameters.device === 'CPU (rwkv.cpp)'; let modelName = ''; let modelPath = ''; @@ -77,7 +78,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean }); }; - if (webgpu) { + if (webgpu || webgpuPython) { if (!['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) { const stModelPath = modelPath.replace(/\.pth$/, '.st'); if (await FileExists(stModelPath)) { @@ -92,7 +93,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean return; } else { toastWithButton(t('Please convert model to safe tensors format first'), t('Convert'), () => { - convertToSt(modelConfig); + convertToSt(modelConfig, navigate); }); commonStore.setStatus({ status: ModelStatus.Offline }); return; @@ -100,7 +101,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean } } - if (!webgpu) { + if (!webgpu && !webgpuPython) { if (['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) { toast(t('Please change Strategy to WebGPU to use safetensors format'), { type: 'error' }); commonStore.setStatus({ status: ModelStatus.Offline }); @@ -176,7 +177,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean const isUsingCudaBeta = modelConfig.modelParameters.device === 'CUDA-Beta'; startServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1', - !!modelConfig.enableWebUI, isUsingCudaBeta, cpp + !!modelConfig.enableWebUI, isUsingCudaBeta, cpp, webgpuPython ).catch((e) => { const errMsg = e.message || e; if (errMsg.includes('path contains space')) @@ -216,7 +217,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean const strategy = getStrategy(modelConfig); let customCudaFile = ''; - if ((modelConfig.modelParameters.device.includes('CUDA') || modelConfig.modelParameters.device === 'Custom') + if ((modelConfig.modelParameters.device.startsWith('CUDA') || modelConfig.modelParameters.device === 'Custom') && modelConfig.modelParameters.useCustomCuda && !strategy.split('->').some(s => ['cuda', 'fp32'].every(v => s.includes(v)))) { if (commonStore.platform === 'windows') { @@ -264,7 +265,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean navigate({ pathname: '/' + buttonName.toLowerCase() }); }; - if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'CUDA-Beta') && + if (modelConfig.modelParameters.device.startsWith('CUDA') && modelConfig.modelParameters.storedLayers < modelConfig.modelParameters.maxStoredLayers && commonStore.monitorData && commonStore.monitorData.totalVram !== 0 && (commonStore.monitorData.usedVram / commonStore.monitorData.totalVram) < 0.9) diff --git a/frontend/src/pages/Configs.tsx b/frontend/src/pages/Configs.tsx index a3c29c5..ea633d9 100644 --- a/frontend/src/pages/Configs.tsx +++ b/frontend/src/pages/Configs.tsx @@ -246,7 +246,7 @@ const Configs: FC = observer(() => { } /> { - selectedConfig.modelParameters.device !== 'WebGPU' ? + !selectedConfig.modelParameters.device.startsWith('WebGPU') ? (selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' ? { onClick={() => convertToGGML(selectedConfig, navigate)} />) : convertToSt(selectedConfig)} /> + onClick={() => convertToSt(selectedConfig, navigate)} /> } { + } /> @@ -281,7 +282,8 @@ const Configs: FC = observer(() => { selectedConfig.modelParameters.device !== 'Custom' && { @@ -302,12 +304,12 @@ const Configs: FC = observer(() => { } /> } { - selectedConfig.modelParameters.device.includes('CUDA') && + selectedConfig.modelParameters.device.startsWith('CUDA') && {getStrategy(selectedConfig)} } /> } { - selectedConfig.modelParameters.device.includes('CUDA') && + selectedConfig.modelParameters.device.startsWith('CUDA') && { }} /> } /> } - {selectedConfig.modelParameters.device.includes('CUDA') &&
} + {selectedConfig.modelParameters.device.startsWith('CUDA') &&
} { displayStrategyImg && { } {selectedConfig.modelParameters.device === 'Custom' &&
} { - (selectedConfig.modelParameters.device.includes('CUDA') || selectedConfig.modelParameters.device === 'Custom') && + (selectedConfig.modelParameters.device.startsWith('CUDA') || selectedConfig.modelParameters.device === 'Custom') && { +export const convertToSt = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => { + const webgpuPython = selectedConfig.modelParameters.device === 'WebGPU (Python)'; + if (webgpuPython) { + const ok = await checkDependencies(navigate); + if (!ok) + return; + } + const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`; if (await FileExists(modelPath)) { toast(t('Start Converting'), { autoClose: 2000, type: 'info' }); const newModelPath = modelPath.replace(/\.pth$/, '.st'); - ConvertSafetensors(modelPath, newModelPath).then(async () => { + const convert = webgpuPython ? + (input: string, output: string) => ConvertSafetensorsWithPython(commonStore.settings.customPythonPath, input, output) + : ConvertSafetensors; + convert(modelPath, newModelPath).then(async () => { if (!await FileExists(newModelPath)) { if (commonStore.platform === 'windows' || commonStore.platform === 'linux') toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' }); diff --git a/frontend/src/utils/index.tsx b/frontend/src/utils/index.tsx index 6f5bfc4..f65acce 100644 --- a/frontend/src/utils/index.tsx +++ b/frontend/src/utils/index.tsx @@ -192,6 +192,7 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) => strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32'; break; case 'WebGPU': + case 'WebGPU (Python)': strategy += params.precision === 'nf4' ? 'fp16i4' : params.precision === 'int8' ? 'fp16i8' : 'fp16'; break; case 'CUDA': @@ -307,7 +308,7 @@ export function getServerRoot(defaultLocalPort: number, isCore: boolean = false) const coreCustomApiUrl = commonStore.settings.coreApiUrl.trim().replace(/\/$/, ''); if (isCore && coreCustomApiUrl) return coreCustomApiUrl; - + const defaultRoot = `http://127.0.0.1:${defaultLocalPort}`; if (commonStore.status.status !== ModelStatus.Offline) return defaultRoot; diff --git a/frontend/wailsjs/go/backend_golang/App.d.ts b/frontend/wailsjs/go/backend_golang/App.d.ts index b594cc1..62be3c8 100755 --- a/frontend/wailsjs/go/backend_golang/App.d.ts +++ b/frontend/wailsjs/go/backend_golang/App.d.ts @@ -16,6 +16,8 @@ export function ConvertModel(arg1:string,arg2:string,arg3:string,arg4:string):Pr export function ConvertSafetensors(arg1:string,arg2:string):Promise; +export function ConvertSafetensorsWithPython(arg1:string,arg2:string,arg3:string):Promise; + export function CopyFile(arg1:string,arg2:string):Promise; export function DeleteFile(arg1:string):Promise; @@ -64,7 +66,7 @@ export function SaveJson(arg1:string,arg2:any):Promise; export function StartFile(arg1:string):Promise; -export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean,arg6:boolean):Promise; +export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean,arg6:boolean,arg7:boolean):Promise; export function StartWebGPUServer(arg1:number,arg2:string):Promise; diff --git a/frontend/wailsjs/go/backend_golang/App.js b/frontend/wailsjs/go/backend_golang/App.js index 3be1c13..203d8f8 100755 --- a/frontend/wailsjs/go/backend_golang/App.js +++ b/frontend/wailsjs/go/backend_golang/App.js @@ -30,6 +30,10 @@ export function ConvertSafetensors(arg1, arg2) { return window['go']['backend_golang']['App']['ConvertSafetensors'](arg1, arg2); } +export function ConvertSafetensorsWithPython(arg1, arg2, arg3) { + return window['go']['backend_golang']['App']['ConvertSafetensorsWithPython'](arg1, arg2, arg3); +} + export function CopyFile(arg1, arg2) { return window['go']['backend_golang']['App']['CopyFile'](arg1, arg2); } @@ -126,8 +130,8 @@ export function StartFile(arg1) { return window['go']['backend_golang']['App']['StartFile'](arg1); } -export function StartServer(arg1, arg2, arg3, arg4, arg5, arg6) { - return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5, arg6); +export function StartServer(arg1, arg2, arg3, arg4, arg5, arg6, arg7) { + return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5, arg6, arg7); } export function StartWebGPUServer(arg1, arg2) {