diff --git a/.gitignore b/.gitignore index 51d7dbd..2c84165 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,7 @@ __pycache__ /backend-python/get-pip.py /py310 *.zip -/cmd-helper.bat \ No newline at end of file +/cmd-helper.bat +/backend-python/wkv_cuda +*.exe +*.old diff --git a/backend-golang/rwkv.go b/backend-golang/rwkv.go index 7875b14..c41bdb1 100644 --- a/backend-golang/rwkv.go +++ b/backend-golang/rwkv.go @@ -48,7 +48,7 @@ func (a *App) InstallPyDep(cnMirror bool) (string, error) { return "", err } ChangeFileLine("./py310/python310._pth", 3, "Lib\\site-packages") - _, err = Cmd(python, "-m", "pip", "install", "torch", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cu117") + _, err = Cmd(python, "-m", "pip", "install", "torch==1.13.1", "torchvision==0.14.1", "torchaudio==0.13.1", "--index-url", "https://download.pytorch.org/whl/cu117") if err != nil { return "", err } diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index a03e258..7a5cb6a 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -1,3 +1,5 @@ +import os +import pathlib from typing import Dict from langchain.llms import RWKV from pydantic import BaseModel @@ -34,6 +36,10 @@ def get_rwkv_config(model: RWKV) -> ModelConfigBody: ) +# os.environ["RWKV_CUDA_ON"] = '1' +# os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}" + + def rwkv_generate(model: RWKV, prompt: str, stop: str = None): model.model_state = None model.model_tokens = [] diff --git a/backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd b/backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd new file mode 100644 index 0000000..09927ff Binary files /dev/null and b/backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd differ diff --git a/backend-python/wkv_cuda_utils/wkv_cuda40.pyd b/backend-python/wkv_cuda_utils/wkv_cuda40.pyd new file mode 100644 index 0000000..efd7f74 Binary files /dev/null and b/backend-python/wkv_cuda_utils/wkv_cuda40.pyd differ diff --git a/backend-python/wkv_cuda_utils/wkv_cuda_model.py b/backend-python/wkv_cuda_utils/wkv_cuda_model.py new file mode 100644 index 0000000..7d17727 --- /dev/null +++ b/backend-python/wkv_cuda_utils/wkv_cuda_model.py @@ -0,0 +1,734 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import types, gc, os, time, re +import torch +from torch.nn import functional as F +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +current_path = os.path.dirname(os.path.abspath(__file__)) + +# https://zhuanlan.zhihu.com/p/612879065 +def LoadPreCompileLibrary(file): + import importlib + import os + + import torch + + # load the custom_op_library and register the custom ops + lib_dir = os.path.dirname(__file__) + if os.name == "nt": + # Register the main torchvision library location on the default DLL path + import ctypes + import sys + + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + if with_load_library_flags: + kernel32.AddDllDirectory.restype = ctypes.c_void_p + + if sys.version_info >= (3, 8): + os.add_dll_directory(lib_dir) + elif with_load_library_flags: + res = kernel32.AddDllDirectory(lib_dir) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' + raise ValueError(err) + + kernel32.SetErrorMode(prev_error_mode) + + loader_details = ( + importlib.machinery.ExtensionFileLoader, + importlib.machinery.EXTENSION_SUFFIXES, + ) + + extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) + ext_specs = extfinder.find_spec(file) + if ext_specs is None: + return False + + try: + torch.ops.load_library(ext_specs.origin) + except OSError as exc: + return False + return True + +######################################################################################################## + +if os.environ.get('RWKV_JIT_ON') != '0': + os.environ["RWKV_JIT_ON"] = '1' + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method + MyStatic = torch.jit.script +else: + MyModule = torch.nn.Module + def __nop(ob): + return ob + MyFunction = __nop + MyStatic = __nop + +if os.environ.get('RWKV_CUDA_ON') == '1': + if LoadPreCompileLibrary('wkv_cuda') is False: + from torch.utils.cpp_extension import load + load( + name=f"wkv_cuda", + sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu"], + verbose=True, + extra_cuda_cflags=["-t 4", "-std=c++17", "--use_fast_math", "-O3", "--extra-device-vectorization"], + is_python_module=False) + + @MyStatic + def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp): + assert 1 * C % min(C, 32) == 0 + assert k.dtype == v.dtype == torch.float16 or k.dtype == v.dtype == torch.float32 + assert w.dtype == u.dtype == aa.dtype == bb.dtype == pp.dtype == torch.float32 + w = w.contiguous() + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + y = torch.empty((T, C), device=w.device, memory_format=torch.contiguous_format, dtype=k.dtype) + torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp) + return y, aa, bb, pp + @MyStatic + def cuda_mm8_seq(B: int, N: int, M: int, x, w, mx, rx, my, ry): + assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype + assert x.dtype == torch.float32 or x.dtype == torch.float16 + assert w.dtype == torch.uint8 + assert x.shape == [B, N] + assert w.shape == [N, M] + assert rx.shape == mx.shape == [M] + assert ry.shape == my.shape == [N, 1] + y = torch.empty((B, M), device=w.device, dtype=x.dtype) + torch.ops.rwkv.mm8_seq(B, N, M, x, w, mx, rx, my, ry, y) + return y + @MyStatic + def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry): + assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype + assert x.dtype == torch.float32 or x.dtype == torch.float16 + assert w.dtype == torch.uint8 + assert x.shape == [N] + assert w.shape == [N, M] + assert rx.shape == mx.shape == [M] + assert ry.shape == my.shape == [N, 1] + y = torch.zeros((M,), device=w.device, dtype=torch.float32) + torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y) + return y.to(dtype=x.dtype) +else: + os.environ["RWKV_CUDA_ON"] = '0' + +######################################################################################################## + +class RWKV(MyModule): + def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None): + super().__init__() + if verbose: + prxxx = lambda *args, **kwargs: print(*args, **kwargs) + else: + prxxx = lambda *args, **kwargs: None + + STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$" + if not re.match(STRATEGY_REGEX, strategy): + raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/") + + strategy = ('->'.join([x.strip() for x in strategy.split('->')])).replace('->', ' -> ') + self.args = types.SimpleNamespace() + args = self.args + args.MODEL_NAME = model + args.strategy_string = strategy + + # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow) + self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0 + prxxx(f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n') + + args.MODEL_NAME = args.MODEL_NAME.strip() + if not args.MODEL_NAME.endswith('.pth'): + args.MODEL_NAME += '.pth' + prxxx(f'Loading {args.MODEL_NAME} ...') + with torch.no_grad(): + self.w = torch.load(args.MODEL_NAME, map_location='cpu') # load model to CPU first + gc.collect() + w = self.w + + ALREADY_CONVERTED = False + if '_strategy' in w: + ALREADY_CONVERTED = True + assert convert_and_save_and_exit == None # you should only convert a raw model + prxxx(f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n") + assert w['_strategy'] == args.strategy_string # if you are using a new strategy, re-convert the model + assert float(w['_version']) >= 0.7 # sometimes you should re-convert using latest convert_model.py + assert w['_rescale_layer'] == self.RESCALE_LAYER + del w['_strategy'] + del w['_version'] + del w['_rescale_layer'] + + args.n_embd = w['emb.weight'].shape[1] + args.n_layer = 0 + keys = list(w.keys()) + for x in keys: + layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0 + args.n_layer = max(args.n_layer, layer_id+1) + + ####################### Compute strategy + + s = [x.strip().split(' ') for x in strategy.split('->')] + plan = [0] * len(s) + stream_i = -1 + stream_count = 0 + to_allocate = args.n_layer + 1 + allocated = 0 + free_slots = 0 + for i in range(len(s)): + si = s[i] + si1 = si[1] + if si1.startswith('fp32'): si[1] = [torch.float] + elif si1.startswith('fp16'): si[1] = [torch.float16] + elif si1.startswith('bf16'): si[1] = [torch.bfloat16] + if si1.endswith('i8'): si[1] += [torch.uint8] + else: si[1] += [si[1][0]] + if len(si) > 2: + ss = si[2] + assert ss.startswith('*') + if ss.endswith('+'): + plan[i] = int(ss[1:-1]) + stream_i = i + else: + plan[i] = int(ss[1:]) + allocated += plan[i] + if allocated >= to_allocate: + plan[i] += to_allocate - allocated + break + else: + free_slots += 1 + if stream_i < 0: + if free_slots > 0 and to_allocate > allocated: + for i in range(len(s)): + if plan[i] == 0: + plan[i] = (to_allocate - allocated) // free_slots + allocated += plan[i] + free_slots -= 1 + if to_allocate > allocated: + plan[len(s)-1] += to_allocate - allocated + else: + if to_allocate > allocated: + stream_count = to_allocate - allocated + plan[stream_i] += stream_count + prxxx(f'Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)') + for i in range(len(s)): + ss = s[i] + if i != stream_i: + prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers') + else: + prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers') + plan[i] += (0 if i == 0 else plan[i-1]) + self.strategy = [None] * (args.n_layer + 1) + strategy = self.strategy + for n in range(args.n_layer + 1): + for i in range(len(s)): + if n < plan[i]: + strategy[n] = types.SimpleNamespace() + strategy[n].device = s[i][0] + strategy[n].atype = s[i][1][0] + strategy[n].wtype = s[i][1][1] + strategy[n].stream = False + if i == stream_i and n >= (plan[i] - stream_count): + strategy[n].stream = True + break + prxxx(f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}",end=' ') + prxxx() + + ####################### Load weights to self.w + + if not ALREADY_CONVERTED: + try: # precompute embedding + w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias']) + except: + w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float()) + del w['blocks.0.ln0.weight'] + del w['blocks.0.ln0.bias'] + + print_need_newline = False + keys = list(w.keys()) + for x in keys: + w[x].requires_grad = False + layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0 + if ('ln_out.' in x) or ('head.' in x): + layer_id = args.n_layer + dd = strategy[layer_id] + DEVICE = dd.device + ATYPE = dd.atype + WTYPE = dd.wtype + + if not ALREADY_CONVERTED: + if self.RESCALE_LAYER > 0: + if 'att.output.weight' in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + if 'ffn.value.weight' in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + + if '.time_' in x: + w[x] = w[x].squeeze() + if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x or 'head.weight' in x: + w[x] = w[x].t() + + if '.time_decay' in x: # need fp32 for this + w[x] = -torch.exp(w[x].float()) + elif '.time_first' in x: # need fp32 for this + w[x] = w[x].float() + else: + if (len(w[x].shape) == 2) and ('emb' not in x): + if WTYPE != torch.uint8: + w[x] = w[x].to(dtype=WTYPE) + else: + w[x] = w[x].float() + + if w[x].shape[0] > w[x].shape[1]: + w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x+'_my'] + w[x+'_mx'] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x+'_mx'] + w[x+'_rx'] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x+'_rx'] + w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x+'_ry'] + else: + w[x+'_mx'] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x+'_mx'] + w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x+'_my'] + w[x+'_rx'] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x+'_rx'] + w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x+'_ry'] + + w[x] = torch.clip(torch.floor(w[x] * 256), min=0, max=255).to(dtype=torch.uint8) + w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE).contiguous() + w[x+'_rx'] = (w[x+'_rx'] / 16).to(dtype=ATYPE).contiguous() + w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE).contiguous() + w[x+'_ry'] = (w[x+'_ry'] / 16).to(dtype=ATYPE).contiguous() + else: + w[x] = w[x].to(dtype=ATYPE) + + if convert_and_save_and_exit == None: + if 'emb.' in x: + w[x] = w[x].contiguous() + elif (dd.stream) and (x.endswith('key.weight') or x.endswith('value.weight') or x.endswith('receptance.weight') or x.endswith('output.weight')): + try: + w[x] = w[x].contiguous().pin_memory() # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :) + except: + print('Note: You are running out of RAM. Get more CPU RAM. Now this will run much slower.') + elif DEVICE != 'cpu': + w[x] = w[x].to(device=DEVICE).contiguous() + + if (dd.stream) or (DEVICE != 'cpu'): + try: + w[x+'_mx'] = w[x+'_mx'].to(device=DEVICE).contiguous() + w[x+'_rx'] = w[x+'_rx'].to(device=DEVICE).contiguous() + w[x+'_my'] = w[x+'_my'].to(device=DEVICE).contiguous() + w[x+'_ry'] = w[x+'_ry'].to(device=DEVICE).contiguous() + except: + pass + + if 'ffn.value.weight' in x: + gc.collect() + if 'cuda' in args.strategy_string: + torch.cuda.empty_cache() + + shape = [i for i in w[x].shape if i != 1] + if len(shape) > 1: + shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}" + else: + shape = f" {str(shape[0]).rjust(5)} " + if layer_id == 0 or layer_id >= args.n_layer-1: + if print_need_newline: + prxxx('\n', end = '') + print_need_newline = False + dt = str(w[x].dtype).replace('torch.', '') + dt = dt.replace('float32', 'f32').replace('bfloat16', 'bf16').replace('float16', 'f16').replace('uint8', 'i8') + prxxx(x.ljust(32), dt.rjust(4), str(w[x].device).rjust(8), shape, ' (pinned)' if w[x].is_pinned() else '') + else: + print_need_newline = True + prxxx('.', end = '', flush = True) + + if convert_and_save_and_exit: + w['_strategy'] = args.strategy_string + w['_rescale_layer'] = self.RESCALE_LAYER + w['_version'] = '0.7' + if not convert_and_save_and_exit.endswith('.pth'): + convert_and_save_and_exit += '.pth' + prxxx(f'Saving to {convert_and_save_and_exit}...') + torch.save(w, convert_and_save_and_exit) + prxxx(f'Converted and saved. Now this will exit.') + exit(0) + + gc.collect() + if 'cuda' in args.strategy_string: + torch.cuda.empty_cache() + + @MyFunction + def torch_mm8_seq(self, x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + @MyFunction + def torch_mm8_one(self, x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + if os.environ.get('RWKV_CUDA_ON') == '1': + @MyFunction + def mm8_seq(self, x, w, mx, rx, my, ry): + if w.device.type == 'cuda' and x.dtype == torch.float16: + B, N, M = x.shape[0], w.shape[0], w.shape[1] + return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry) + else: + return self.torch_mm8_seq(x, w, mx, rx, my, ry) + @MyFunction + def mm8_one(self, x, w, mx, rx, my, ry): + if w.device.type == 'cuda': + N, M = w.shape[0], w.shape[1] + return cuda_mm8_one(N, M, x, w, mx, rx, my, ry) + else: + return self.torch_mm8_one(x, w, mx, rx, my, ry) + else: + @MyFunction + def mm8_seq(self, x, w, mx, rx, my, ry): + return self.torch_mm8_seq(x, w, mx, rx, my, ry) + @MyFunction + def mm8_one(self, x, w, mx, rx, my, ry): + return self.torch_mm8_one(x, w, mx, rx, my, ry) + + ######################################################################################################## + + @MyFunction + def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + vx = torch.square(torch.relu(kx @ kw)) + out = r * (vx @ vw) + return x + out, xx + + @MyFunction + def ffn_one_i8(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.square(torch.relu(self.mm8_one(kx, kw, kmx, krx, kmy, kry))) + out = r * (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)) + return x + out, xx + + ######################################################################################################## + + @MyFunction + def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + vx = torch.square(torch.relu(kx @ kw)) + out = r * (vx @ vw) + return x + out, xx[-1,:] + + @MyFunction + def ffn_seq_i8(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.square(torch.relu(self.mm8_seq(kx, kw, kmx, krx, kmy, kry))) + out = r * (self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)) + return x + out, xx[-1,:] + + ######################################################################################################## + + @MyFunction + def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + k = (kx @ kw).float() + v = (vx @ vw).float() + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + out = (r * wkv) @ ow + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + @MyFunction + def att_one_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) + k = (self.mm8_one(kx, kw, kmx, krx, kmy, kry)).float() + v = (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)).float() + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + out = self.mm8_one(r * wkv, ow, omx, orx, omy, ory) + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + ######################################################################################################## + + @MyFunction + def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + k = (kx @ kw).float() + v = (vx @ vw).float() + + T = x.shape[0] + for t in range(T): + kk = k[t] + vv = v[t] + ww = t_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + aa = e1 * aa + e2 * vv + bb = e1 * bb + e2 + pp = p + out = (r * sx) @ ow + return x + out, xx[-1,:], aa, bb, pp + + @MyFunction + def att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) + k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry).float() + v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry).float() + + T = x.shape[0] + for t in range(T): + kk = k[t] + vv = v[t] + ww = t_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + aa = e1 * aa + e2 * vv + bb = e1 * bb + e2 + pp = p + out = self.mm8_seq(r * sx, ow, omx, orx, omy, ory) + return x + out, xx[-1,:], aa, bb, pp + + ######################################################################################################## + + if os.environ["RWKV_CUDA_ON"] == '1': + @MyFunction + def cuda_att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + T, C = x.size() + xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + k = kx @ kw + v = vx @ vw + y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) + + out = (r * y) @ ow + return x + out, xx[-1,:], aa, bb, pp + + @MyFunction + def cuda_att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): + T, C = x.size() + xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) + k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry) + v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry) + y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) + + out = self.mm8_seq(r * y, ow, omx, orx, omy, ory) + return x + out, xx[-1,:], aa, bb, pp + + ######################################################################################################## + + def forward(self, tokens, state, full_output=False): + with torch.no_grad(): + w = self.w + args = self.args + + if state == None: + state = [None] * args.n_layer * 5 + for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i*5+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() + state[i*5+1] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() + state[i*5+2] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() + state[i*5+3] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() - 1e30 + state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() + + seq_mode = len(tokens) > 1 + + x = w['emb.weight'][tokens if seq_mode else tokens[0]] + + for i in range(args.n_layer): + bbb = f'blocks.{i}.' + att = f'blocks.{i}.att.' + ffn = f'blocks.{i}.ffn.' + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + wtype = dd.wtype + if seq_mode: + if 'cuda' in str(dev) and os.environ["RWKV_CUDA_ON"] == '1': + ATT = self.cuda_att_seq if wtype != torch.uint8 else self.cuda_att_seq_i8 + else: + ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8 + FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8 + else: + ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8 + FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8 + + x = x.to(dtype=atype, device=dev) + + kw = w[f'{att}key.weight'] + vw = w[f'{att}value.weight'] + rw = w[f'{att}receptance.weight'] + ow = w[f'{att}output.weight'] + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + ow = ow.to(device=dev, non_blocking=True) + kmx = w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x + krx = w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x + kmy = w[f'{att}key.weight_my'] if wtype == torch.uint8 else x + kry = w[f'{att}key.weight_ry'] if wtype == torch.uint8 else x + vmx = w[f'{att}value.weight_mx'] if wtype == torch.uint8 else x + vrx = w[f'{att}value.weight_rx'] if wtype == torch.uint8 else x + vmy = w[f'{att}value.weight_my'] if wtype == torch.uint8 else x + vry = w[f'{att}value.weight_ry'] if wtype == torch.uint8 else x + rmx = w[f'{att}receptance.weight_mx'] if wtype == torch.uint8 else x + rrx = w[f'{att}receptance.weight_rx'] if wtype == torch.uint8 else x + rmy = w[f'{att}receptance.weight_my'] if wtype == torch.uint8 else x + rry = w[f'{att}receptance.weight_ry'] if wtype == torch.uint8 else x + omx = w[f'{att}output.weight_mx'] if wtype == torch.uint8 else x + orx = w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x + omy = w[f'{att}output.weight_my'] if wtype == torch.uint8 else x + ory = w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3], + w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'], + w[f'{att}time_mix_k'], w[f'{att}time_mix_v'], w[f'{att}time_mix_r'], + w[f'{att}time_decay'], w[f'{att}time_first'], + kw, vw, rw, ow, + kmx, krx, kmy, kry, + vmx, vrx, vmy, vry, + rmx, rrx, rmy, rry, + omx, orx, omy, ory, + ) + if dd.stream: + del kw, vw, rw, ow + + kw = w[f'{ffn}key.weight'] + vw = w[f'{ffn}value.weight'] + rw = w[f'{ffn}receptance.weight'] + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + kmx = w[f'{ffn}key.weight_mx'] if wtype == torch.uint8 else x + krx = w[f'{ffn}key.weight_rx'] if wtype == torch.uint8 else x + kmy = w[f'{ffn}key.weight_my'] if wtype == torch.uint8 else x + kry = w[f'{ffn}key.weight_ry'] if wtype == torch.uint8 else x + vmx = w[f'{ffn}value.weight_mx'] if wtype == torch.uint8 else x + vrx = w[f'{ffn}value.weight_rx'] if wtype == torch.uint8 else x + vmy = w[f'{ffn}value.weight_my'] if wtype == torch.uint8 else x + vry = w[f'{ffn}value.weight_ry'] if wtype == torch.uint8 else x + rmx = w[f'{ffn}receptance.weight_mx'] if wtype == torch.uint8 else x + rrx = w[f'{ffn}receptance.weight_rx'] if wtype == torch.uint8 else x + rmy = w[f'{ffn}receptance.weight_my'] if wtype == torch.uint8 else x + rry = w[f'{ffn}receptance.weight_ry'] if wtype == torch.uint8 else x + x, state[i*5+4] = FFN( + x, state[i*5+4], + w[f'{bbb}ln2.weight'], w[f'{bbb}ln2.bias'], + w[f'{ffn}time_mix_k'], w[f'{ffn}time_mix_r'], + kw, vw, rw, + kmx, krx, kmy, kry, + vmx, vrx, vmy, vry, + rmx, rrx, rmy, rry, + ) + if dd.stream: + del kw, vw, rw + + if self.RESCALE_LAYER > 0: + if (i+1) % self.RESCALE_LAYER == 0: + x = x / 2 + + dd = self.strategy[args.n_layer] + x = x[-1,:] if (seq_mode and (not full_output)) else x + x = x.to(dtype=dd.atype, device=dd.device) + + x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias']) + if w['head.weight'].dtype != torch.uint8: + x = x @ w['head.weight'] + else: + if seq_mode and full_output: + x = self.mm8_seq(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) + else: + x = self.mm8_one(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) + + return x.float(), state