# RWKV-Runner/finetune/lora/v6/src/model.py
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import functools
from torch.utils.checkpoint import checkpoint as torch_checkpoint
import os, math, gc, importlib.util  # importlib.util is used below for find_spec
import torch
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
if importlib.util.find_spec("deepspeed"):
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from torch._lowrank import svd_lowrank
import bitsandbytes as bnb
from .infctx_module import *
from einops import rearrange
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
LORA_CONFIG = {
"r": 0,
"alpha": 0,
"dropout": 0,
"parts": {"att", "ln", "time", "ffn"},
}
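# LORA_CONFIG is a module-level registry that the training entry point is expected to
# populate before the model is constructed: "r", "alpha" and "dropout" are the usual LoRA
# rank, scaling numerator and dropout rate, and "parts" selects which sub-modules ("att",
# "ffn", "ln", "time") get LoRA-wrapped linears. With r == 0 the factory functions below
# fall back to plain nn.Linear.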
class LoraLinear(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
assert bias == False, "Biased LoraLinear not supported"
r, alpha, dropout = (
LORA_CONFIG["r"],
LORA_CONFIG["alpha"],
LORA_CONFIG["dropout"],
)
self.lora_A = nn.Parameter(torch.empty(r, in_features))
self.lora_B = nn.Parameter(torch.empty(out_features, r))
self.lora_dropout = nn.Dropout(dropout)
self.scaling = alpha / r
self.r = r
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
self.pissa = False
self.is_quant = False
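# PiSSA initialization: factor the top-r singular components of the frozen weight into
# lora_B @ lora_A (each side carrying the square root of the singular values) and subtract
# that product from the base weight, so the initial forward pass is unchanged while the
# principal directions become the trainable low-rank update.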
def pissa_init(self, svd_niter):
self.pissa = True
Ur, Sr, Vr = svd_lowrank(self.weight.data, self.r, niter=svd_niter)
Vhr = Vr.t()
lora_A = torch.diag(torch.sqrt(Sr)) @ Vhr
lora_B = Ur @ torch.diag(torch.sqrt(Sr))
self.lora_A.data = lora_A
self.lora_B.data = lora_B
self.weight.data = self.weight.data - lora_B @ lora_A
def quant(self, quant_type):
self.is_quant = True
self.quant_type = quant_type
if self.quant_type == "4bit":
self.weight.data, self.qstate = bnb.functional.quantize_4bit(
(self.weight.data).to("cuda")
)
elif self.quant_type == "nf4":
self.weight.data, self.qstate = bnb.functional.quantize_nf4(
(self.weight.data).to("cuda")
)
elif self.quant_type == "fp4":
self.weight.data, self.qstate = bnb.functional.quantize_fp4(
(self.weight.data).to("cuda")
)
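# Forward pass: if quantized, the 4-bit base weight is dequantized back to bf16 on the fly
# before the matmul, then the low-rank update is added. The PiSSA branch applies neither
# dropout nor the alpha/r scaling (A and B already carry the SVD magnitudes), while the
# plain LoRA branch adds scaling * lora_B(lora_A(dropout(x))).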
def forward(self, x):
if self.is_quant:
if self.quant_type == "4bit":
if self.pissa:
return F.linear(
x,
bnb.functional.dequantize_4bit(
self.weight.data, quant_state=self.qstate
).to(torch.bfloat16),
) + F.linear(F.linear(x, self.lora_A), self.lora_B)
return F.linear(
x,
bnb.functional.dequantize_4bit(
self.weight.data, quant_state=self.qstate
).to(torch.bfloat16),
) + self.scaling * F.linear(
F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
)
elif self.quant_type == "nf4":
if self.pissa:
return F.linear(
x,
bnb.functional.dequantize_nf4(
self.weight.data, quant_state=self.qstate
).to(torch.bfloat16),
) + F.linear(F.linear(x, self.lora_A), self.lora_B)
return F.linear(
x,
bnb.functional.dequantize_nf4(
self.weight.data, quant_state=self.qstate
).to(torch.bfloat16),
) + self.scaling * F.linear(
F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
)
elif self.quant_type == "fp4":
if self.pissa:
return F.linear(
x,
bnb.functional.dequantize_fp4(
self.weight.data, quant_state=self.qstate
).to(torch.bfloat16),
) + F.linear(F.linear(x, self.lora_A), self.lora_B)
return F.linear(
x,
bnb.functional.dequantize_fp4(
self.weight.data, quant_state=self.qstate
).to(torch.bfloat16),
) + self.scaling * F.linear(
F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
)
if self.pissa:
return F.linear(x, self.weight) + F.linear(
F.linear(x, self.lora_A), self.lora_B
)
return F.linear(x, self.weight) + self.scaling * F.linear(
F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
)
@functools.wraps(LoraLinear)
def make_linear_att(*args, **kwargs):
if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
return LoraLinear(*args, **kwargs)
else:
return nn.Linear(*args, **kwargs)
@functools.wraps(LoraLinear)
def make_linear_ffn(*args, **kwargs):
if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
return LoraLinear(*args, **kwargs)
else:
return nn.Linear(*args, **kwargs)
try:
print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"])
except KeyError:
os.environ["RWKV_MY_TESTING"] = ""
def __nop(ob):
return ob
MyModule = nn.Module
MyFunction = __nop
if os.environ["RWKV_JIT_ON"] == "1":
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
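# With RWKV_JIT_ON=1 the modules below become TorchScript ScriptModules and their hot
# paths (the jit_func* methods) are compiled via torch.jit.script_method; otherwise they
# run as plain eager nn.Modules.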
########################################################################################################
# CUDA Kernel
########################################################################################################
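# Kernel backend selection: WKV="fla" routes the WKV recurrence through the kernels of the
# flash-linear-attention (`fla`) package, while any other value JIT-compiles the bundled
# CUDA kernels below via torch.utils.cpp_extension.load. The fla wrappers convert the
# decay to -exp(w) and reshape (B, T, H*D) tensors to (B, H, T, D) before calling
# chunk_rwkv6.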
if os.environ["WKV"] == "fla":
if "x060" in os.environ["RWKV_MY_TESTING"]:
if (
os.environ["RWKV_TRAIN_TYPE"] == "infctx"
and "x060" in os.environ["RWKV_MY_TESTING"]
):
def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s):
r = rearrange(r, "b l (h d) -> b h l d", h=H)
k = rearrange(k, "b l (h d) -> b h l d", h=H)
v = rearrange(v, "b l (h d) -> b h l d", h=H)
w = rearrange(-torch.exp(w), "b l (h d) -> b h l d", h=H)
o, state = chunk_rwkv6(
r, k, v, w, u=u, initial_state=s, output_final_state=True
)
x = rearrange(o, "b h l d -> b l (h d)")
return x, state
elif os.environ["RWKV_TRAIN_TYPE"] == "states":
def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s):
r = rearrange(r, "b l (h d) -> b h l d", h=H)
k = rearrange(k, "b l (h d) -> b h l d", h=H)
v = rearrange(v, "b l (h d) -> b h l d", h=H)
w = rearrange(-torch.exp(w), "b l (h d) -> b h l d", h=H)
o, _ = chunk_rwkv6(
r, k, v, w, u=u, initial_state=s, output_final_state=False
)
x = rearrange(o, "b h l d -> b l (h d)")
return x
else:
def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
r = rearrange(r, "b l (h d) -> b h l d", h=H)
k = rearrange(k, "b l (h d) -> b h l d", h=H)
v = rearrange(v, "b l (h d) -> b h l d", h=H)
w = rearrange(-torch.exp(w), "b l (h d) -> b h l d", h=H)
o, _ = chunk_rwkv6(
r, k, v, w, u=u, initial_state=None, output_final_state=False
)
x = rearrange(o, "b h l d -> b l (h d)")
return x
else:
from torch.utils.cpp_extension import load
HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"])
if "x060" in os.environ["RWKV_MY_TESTING"]:
if os.environ["RWKV_TRAIN_TYPE"] == "infctx":
wkv6state_cuda = load(
name="wkv6infctx",
sources=[
"finetune/lora/v6/cuda/wkv6infctx_op.cpp",
f"finetune/lora/v6/cuda/wkv6infctx_cuda.cu",
],
verbose=True,
extra_cuda_cflags=[
"-res-usage",
"--use_fast_math",
"-O3",
"-Xptxas -O3",
"--extra-device-vectorization",
f"-D_N_={HEAD_SIZE}",
f"-D_T_={int(os.environ['RWKV_CTXLEN'])}",
],
)
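# This load() call (and the ones in the branches below) compiles the matching .cpp/.cu
# pair at import time, baking the head size (-D_N_) and, where used, the context length
# (-D_T_) into the binary; changing RWKV_HEAD_SIZE_A or RWKV_CTXLEN therefore triggers a
# recompile.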
class WKV_6STATE(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, H, r, k, v, w, u, s):
with torch.no_grad():
assert r.dtype == torch.bfloat16
assert k.dtype == torch.bfloat16
assert v.dtype == torch.bfloat16
assert w.dtype == torch.bfloat16
assert u.dtype == torch.bfloat16
assert s.dtype == torch.bfloat16
assert HEAD_SIZE == C // H
ctx.B = B
ctx.T = T
ctx.C = C
ctx.H = H
assert r.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert w.is_contiguous()
assert u.is_contiguous()
assert s.is_contiguous()
ctx.save_for_backward(r, k, v, w, u, s)
y = torch.empty(
(B, T, C),
device=r.device,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
wkv6state_cuda.forward(B, T, C, H, r, k, v, w, u, s, y)
return y
@staticmethod
def backward(ctx, gy):
with torch.no_grad():
assert gy.dtype == torch.bfloat16
B = ctx.B
T = ctx.T
C = ctx.C
H = ctx.H
assert gy.is_contiguous()
r, k, v, w, u, s = ctx.saved_tensors
gr = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gk = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gv = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gw = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gu = torch.empty(
(B, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gs = torch.empty(
(B, H, C // H, C // H),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
wkv6state_cuda.backward(
B, T, C, H, r, k, v, w, u, s, gy, gr, gk, gv, gw, gu, gs
)
gu = torch.sum(gu, 0).view(H, C // H)
gs = torch.sum(gs, 0).view(H, C // H, C // H)
return (None, None, None, None, gr, gk, gv, gw, gu, gs)
def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s):
x = WKV_6STATE.apply(B, T, C, H, r, k, v, w, u, s)
return x, s
elif os.environ["RWKV_TRAIN_TYPE"] == "states":
wkv6state_cuda = load(
name="wkv6state",
sources=[
"finetune/lora/v6/cuda/wkv6state_op.cpp",
f"finetune/lora/v6/cuda/wkv6state_cuda.cu",
],
verbose=True,
extra_cuda_cflags=[
"-res-usage",
"--use_fast_math",
"-O3",
"-Xptxas -O3",
"--extra-device-vectorization",
f"-D_N_={HEAD_SIZE}",
f"-D_T_={int(os.environ['RWKV_CTXLEN'])}",
],
)
class WKV_6STATE(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, H, r, k, v, w, u, s):
with torch.no_grad():
assert r.dtype == torch.bfloat16
assert k.dtype == torch.bfloat16
assert v.dtype == torch.bfloat16
assert w.dtype == torch.bfloat16
assert u.dtype == torch.bfloat16
assert s.dtype == torch.bfloat16
assert HEAD_SIZE == C // H
ctx.B = B
ctx.T = T
ctx.C = C
ctx.H = H
assert r.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert w.is_contiguous()
assert u.is_contiguous()
assert s.is_contiguous()
ctx.save_for_backward(r, k, v, w, u, s)
y = torch.empty(
(B, T, C),
device=r.device,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
wkv6state_cuda.forward(B, T, C, H, r, k, v, w, u, s, y)
return y
@staticmethod
def backward(ctx, gy):
with torch.no_grad():
assert gy.dtype == torch.bfloat16
B = ctx.B
T = ctx.T
C = ctx.C
H = ctx.H
assert gy.is_contiguous()
r, k, v, w, u, s = ctx.saved_tensors
gr = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gk = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gv = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gw = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gu = torch.empty(
(B, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gs = torch.empty(
(B, H, C // H, C // H),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
wkv6state_cuda.backward(
B, T, C, H, r, k, v, w, u, s, gy, gr, gk, gv, gw, gu, gs
)
gu = torch.sum(gu, 0).view(H, C // H)
gs = torch.sum(gs, 0).view(H, C // H, C // H)
return (None, None, None, None, gr, gk, gv, gw, gu, gs)
def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s):
return WKV_6STATE.apply(B, T, C, H, r, k, v, w, u, s)
else:
wkv6_cuda = load(
name="wkv6",
sources=[
"finetune/lora/v6/cuda/wkv6_op.cpp",
f"finetune/lora/v6/cuda/wkv6_cuda.cu",
],
verbose=True,
extra_cuda_cflags=[
"-res-usage",
"--use_fast_math",
"-O3",
"-Xptxas -O3",
"--extra-device-vectorization",
f"-D_N_={HEAD_SIZE}",
f"-D_T_={int(os.environ['RWKV_CTXLEN'])}",
],
)
class WKV_6(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, H, r, k, v, w, u):
with torch.no_grad():
assert r.dtype == torch.bfloat16
assert k.dtype == torch.bfloat16
assert v.dtype == torch.bfloat16
assert w.dtype == torch.bfloat16
assert u.dtype == torch.bfloat16
assert HEAD_SIZE == C // H
ctx.B = B
ctx.T = T
ctx.C = C
ctx.H = H
assert r.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert w.is_contiguous()
assert u.is_contiguous()
ew = (-torch.exp(w.float())).contiguous()
ctx.save_for_backward(r, k, v, ew, u)
y = torch.empty(
(B, T, C),
device=r.device,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
wkv6_cuda.forward(B, T, C, H, r, k, v, ew, u, y)
return y
@staticmethod
def backward(ctx, gy):
with torch.no_grad():
assert gy.dtype == torch.bfloat16
B = ctx.B
T = ctx.T
C = ctx.C
H = ctx.H
assert gy.is_contiguous()
r, k, v, ew, u = ctx.saved_tensors
gr = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gk = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gv = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gw = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
gu = torch.empty(
(B, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-100, 100)
wkv6_cuda.backward(
B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu
)
gu = torch.sum(gu, 0).view(H, C // H)
return (None, None, None, None, gr, gk, gv, gw, gu)
def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
return WKV_6.apply(B, T, C, H, r, k, v, w, u)
else:
wkv5_cuda = load(
name="wkv5",
sources=[
"finetune/lora/v6/cuda/wkv5_op.cpp",
f"finetune/lora/v6/cuda/wkv5_cuda.cu",
],
verbose=True,
extra_cuda_cflags=[
"-res-usage",
"--use_fast_math",
"-O3",
"-Xptxas -O3",
"--extra-device-vectorization",
f"-D_N_={HEAD_SIZE}",
],
)
class WKV_5(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, H, r, k, v, w, u):
with torch.no_grad():
assert r.dtype == torch.bfloat16
assert k.dtype == torch.bfloat16
assert v.dtype == torch.bfloat16
assert w.dtype == torch.bfloat16
assert u.dtype == torch.bfloat16
assert HEAD_SIZE == C // H
ctx.B = B
ctx.T = T
ctx.C = C
ctx.H = H
assert r.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert w.is_contiguous()
assert u.is_contiguous()
ew = (-torch.exp(w.float())).contiguous()
eew = (torch.exp(ew)).contiguous()
ctx.save_for_backward(r, k, v, eew, ew, u)
y = torch.empty(
(B, T, C),
device=r.device,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
return y
@staticmethod
def backward(ctx, gy):
with torch.no_grad():
assert gy.dtype == torch.bfloat16
B = ctx.B
T = ctx.T
C = ctx.C
H = ctx.H
assert gy.is_contiguous()
r, k, v, eew, ew, u = ctx.saved_tensors
gr = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
gk = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
gv = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
gw = torch.empty(
(B, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
gu = torch.empty(
(B, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
wkv5_cuda.backward(
B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu
)
gw = torch.sum(gw, 0).view(H, C // H)
gu = torch.sum(gu, 0).view(H, C // H)
return (None, None, None, None, gr, gk, gv, gw, gu)
def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
return WKV_5.apply(B, T, C, H, r, k, v, w, u)
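# All of the autograd.Function wrappers above follow the same pattern: forward() launches
# the compiled kernel into a preallocated bf16 output buffer, and backward() allocates
# bf16 gradient buffers for the kernel to fill, then sums the gradients of parameters that
# are shared across the batch (u, and for the v5/state kernels also w or s) over the batch
# dimension before returning them.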
########################################################################################################
class RWKV_TimeMix_RWKV5(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.head_size = args.head_size_a
assert HEAD_SIZE == self.head_size # change HEAD_SIZE to match args.head_size_a
self.n_head = args.dim_att // self.head_size
assert args.dim_att % self.n_head == 0
self.head_size_divisor = args.head_size_divisor
with torch.no_grad():
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
# fancy time_mix
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_v = nn.Parameter(
torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
)
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
# fancy time_decay
decay_speed = torch.ones(args.dim_att)
for n in range(args.dim_att):
decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
0.7 + 1.3 * ratio_0_to_1
)
self.time_decay = nn.Parameter(
decay_speed.reshape(self.n_head, self.head_size)
)
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
tmp = torch.zeros(args.dim_att)
for n in range(args.dim_att):
zigzag = ((n + 1) % 3 - 1) * 0.1
tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
self.output = make_linear_att(args.dim_att, args.n_embd, bias=False)
self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)
@MyFunction
def jit_func(self, x):
B, T, C = x.size()
xx = self.time_shift(
x
) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)
r = self.receptance(xr)
k = self.key(xk)
v = self.value(xv)
g = F.silu(self.gate(xg))
return r, k, v, g
@MyFunction
def jit_func_2(self, x, g):
B, T, C = x.size()
x = x.view(B * T, C)
x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
x = self.output(x * g)
return x
def forward(self, x):
B, T, C = x.size()
H = self.n_head
r, k, v, g = self.jit_func(x)
x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)
return self.jit_func_2(x, g)
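# RWKV_Tmix_x060 is the v6 time-mix block. Compared with the v5 block above, the static
# time_mix_* interpolation is replaced by data-dependent token shift: a small bottleneck
# projection (time_maa_w1 / time_maa_w2) produces per-token mixing offsets for w, k, v, r
# and g, and the decay gets a low-rank, input-dependent correction
# (time_decay_w1 / time_decay_w2) added to the learned per-channel time_decay.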
class RWKV_Tmix_x060(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.head_size = args.head_size_a
self.n_head = args.dim_att // self.head_size
assert args.dim_att % self.n_head == 0
with torch.no_grad():
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
# fancy time_mix
self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_v = nn.Parameter(
1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
)
self.time_maa_r = nn.Parameter(
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
)
self.time_maa_g = nn.Parameter(
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
)
TIME_MIX_EXTRA_DIM = 32 # generate TIME_MIX for w,k,v,r,g
if args.n_embd == 4096:
TIME_MIX_EXTRA_DIM = TIME_MIX_EXTRA_DIM * 2
self.time_maa_w1 = nn.Parameter(
torch.zeros(args.n_embd, TIME_MIX_EXTRA_DIM * 5).uniform_(-1e-4, 1e-4)
)
self.time_maa_w2 = nn.Parameter(
torch.zeros(5, TIME_MIX_EXTRA_DIM, args.n_embd).uniform_(-1e-4, 1e-4)
)
# fancy time_decay
decay_speed = torch.ones(args.dim_att)
for n in range(args.dim_att):
decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
0.7 + 1.3 * ratio_0_to_1
)
self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
TIME_DECAY_EXTRA_DIM = 64
if args.n_embd == 4096:
TIME_DECAY_EXTRA_DIM = TIME_DECAY_EXTRA_DIM * 2
self.time_decay_w1 = nn.Parameter(
torch.zeros(args.n_embd, TIME_DECAY_EXTRA_DIM).uniform_(-1e-4, 1e-4)
)
self.time_decay_w2 = nn.Parameter(
torch.zeros(TIME_DECAY_EXTRA_DIM, args.dim_att).uniform_(-1e-4, 1e-4)
)
tmp = torch.zeros(args.dim_att)
for n in range(args.dim_att):
zigzag = ((n + 1) % 3 - 1) * 0.1
tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
self.output = make_linear_att(args.dim_att, args.n_embd, bias=False)
self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
self.ln_x = nn.GroupNorm(
self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
)
@MyFunction
def jit_func(self, x):
B, T, C = x.size()
xx = self.time_shift(x) - x
xxx = x + xx * self.time_maa_x
xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
mw, mk, mv, mr, mg = xxx.unbind(dim=0)
xw = x + xx * (self.time_maa_w + mw)
xk = x + xx * (self.time_maa_k + mk)
xv = x + xx * (self.time_maa_v + mv)
xr = x + xx * (self.time_maa_r + mr)
xg = x + xx * (self.time_maa_g + mg)
r = self.receptance(xr)
k = self.key(xk)
v = self.value(xv)
g = F.silu(self.gate(xg))
ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
w = self.time_decay + ww
return r, k, v, g, w
@MyFunction
def jit_func_2(self, x, g):
B, T, C = x.size()
x = x.view(B * T, C)
x = self.ln_x(x).view(B, T, C)
x = self.output(x * g)
return x
def forward(self, x):
B, T, C = x.size()
H = self.n_head
r, k, v, g, w = self.jit_func(x)
x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)
return self.jit_func_2(x, g)
########################################################################################################
class RWKV_Tmix_x060_state(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.head_size = args.head_size_a
self.n_head = args.dim_att // self.head_size
assert args.dim_att % self.n_head == 0
with torch.no_grad():
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
# fancy time_mix
self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_v = nn.Parameter(
1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
)
self.time_maa_r = nn.Parameter(
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
)
self.time_maa_g = nn.Parameter(
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
)
D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g
if args.n_embd == 4096:
D_MIX_LORA = D_MIX_LORA * 2
self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
self.time_maa_w2 = nn.Parameter(
torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01)
)
# fancy time_decay
decay_speed = torch.ones(args.dim_att)
for n in range(args.dim_att):
decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
0.7 + 1.3 * ratio_0_to_1
)
self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
D_DECAY_LORA = 64
if args.n_embd == 4096:
D_DECAY_LORA = D_DECAY_LORA * 2
self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
self.time_decay_w2 = nn.Parameter(
torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01)
)
tmp = torch.zeros(args.dim_att)
for n in range(args.dim_att):
zigzag = ((n + 1) % 3 - 1) * 0.1
tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
self.time_state = nn.Parameter(
torch.zeros(self.n_head, self.head_size, self.head_size)
)
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.ln_x = nn.GroupNorm(
self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
)
@MyFunction
def jit_func(self, x):
B, T, C = x.size()
xx = self.time_shift(x) - x
xxx = x + xx * self.time_maa_x
xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
mw, mk, mv, mr, mg = xxx.unbind(dim=0)
xw = x + xx * (self.time_maa_w + mw)
xk = x + xx * (self.time_maa_k + mk)
xv = x + xx * (self.time_maa_v + mv)
xr = x + xx * (self.time_maa_r + mr)
xg = x + xx * (self.time_maa_g + mg)
r = self.receptance(xr)
k = self.key(xk)
v = self.value(xv)
g = F.silu(self.gate(xg))
ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
w = self.time_decay + ww
return r, k, v, g, w
@MyFunction
def jit_func_2(self, x, g):
B, T, C = x.size()
x = x.view(B * T, C)
x = self.ln_x(x).view(B, T, C)
x = self.output(x * g)
return x
def forward(self, x):
B, T, C = x.size()
H = self.n_head
r, k, v, g, w = self.jit_func(x)
x = RUN_CUDA_RWKV6_STATE(
B, T, C, H, r, k, v, w, u=self.time_faaaa, s=self.time_state
)
return self.jit_func_2(x, g)
########################################################################################################
class RWKV_ChannelMix(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.relu(k) ** 2
kv = self.value(k)
return torch.sigmoid(self.receptance(xr)) * kv
class RWKV_CMix_x060(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x):
xx = self.time_shift(x) - x
xk = x + xx * self.time_maa_k
xr = x + xx * self.time_maa_r
k = self.key(xk)
k = torch.relu(k) ** 2
kv = self.value(k)
return torch.sigmoid(self.receptance(xr)) * kv
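# The channel-mix (FFN) blocks use a gated squared-ReLU form: k = relu(key(xk)) ** 2 is
# projected back by value() and gated with sigmoid(receptance(xr)). The x060 variant
# differs from RWKV_ChannelMix only in how the token-shift mix is parameterized
# (additive time_maa_* offsets instead of time_mix_* interpolation weights).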
########################################################################################################
class MishGLU(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad():
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
x = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
x[0, 0, i] = i / args.n_embd
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x):
xx = self.time_shift(x)
xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
a = self.aa(xa)
b = self.bb(xb)
return self.value(a * F.mish(b))
########################################################################################################
class RWKV_Tmix_x060_infctx(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.head_size = args.head_size_a
self.n_head = args.dim_att // self.head_size
assert args.dim_att % self.n_head == 0
with torch.no_grad():
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
# fancy time_mix
self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_v = nn.Parameter(
1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
)
self.time_maa_r = nn.Parameter(
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
)
self.time_maa_g = nn.Parameter(
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
)
D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g
if args.n_embd == 4096:
D_MIX_LORA = D_MIX_LORA * 2
self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
self.time_maa_w2 = nn.Parameter(
torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01)
)
# fancy time_decay
decay_speed = torch.ones(args.dim_att)
for n in range(args.dim_att):
decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
0.7 + 1.3 * ratio_0_to_1
)
self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
D_DECAY_LORA = 64
if args.n_embd == 4096:
D_DECAY_LORA = D_DECAY_LORA * 2
self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
self.time_decay_w2 = nn.Parameter(
torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01)
)
tmp = torch.zeros(args.dim_att)
for n in range(args.dim_att):
zigzag = ((n + 1) % 3 - 1) * 0.1
tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
# self.time_state = nn.Parameter(torch.zeros(self.n_head, self.head_size, self.head_size))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.ln_x = nn.GroupNorm(
self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
)
@MyFunction
def jit_func(self, x, shift_state):
B, T, C = x.size()
xx = torch.concat((shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x
xxx = x + xx * self.time_maa_x
xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
mw, mk, mv, mr, mg = xxx.unbind(dim=0)
xw = x + xx * (self.time_maa_w + mw)
xk = x + xx * (self.time_maa_k + mk)
xv = x + xx * (self.time_maa_v + mv)
xr = x + xx * (self.time_maa_r + mr)
xg = x + xx * (self.time_maa_g + mg)
r = self.receptance(xr)
k = self.key(xk)
v = self.value(xv)
g = F.silu(self.gate(xg))
ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
w = self.time_decay + ww
return r, k, v, g, w, x[:, -1]
@MyFunction
def jit_func_2(self, x, g, timemixstate: TimeMixState):
B, T, C = x.size()
x = x.view(B * T, C)
x = self.ln_x(x).view(B, T, C)
x = self.output(x * g)
return x, timemixstate
def forward(self, x, last_state: TimeMixState):
B, T, C = x.size()
H = self.n_head
shift_state = last_state.shift_state
r, k, v, g, w, lx = self.jit_func(x, shift_state)
######
wkv_state = last_state.wkv_state.clone().contiguous()
x, wkv_state = RUN_CUDA_RWKV6_STATE(
B, T, C, H, r, k, v, w, u=self.time_faaaa, s=wkv_state
)
# wkv_state = last_state.wkv_state
return self.jit_func_2(x, g, TimeMixState(lx, wkv_state))
class RWKV_CMix_x060_infctx(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x, last_state: ChannelMixState):
xx = torch.concat((last_state.shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x
xk = x + xx * self.time_maa_k
xr = x + xx * self.time_maa_r
k = self.key(xk)
k = torch.relu(k) ** 2
kv = self.value(k)
return torch.sigmoid(self.receptance(xr)) * kv, ChannelMixState(x[:, -1])
########################################################################################################
# The RWKV Model with our blocks
########################################################################################################
class Block(nn.Module):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.ln1 = nn.LayerNorm(args.n_embd)
self.ln2 = nn.LayerNorm(args.n_embd)
if self.layer_id == 0:
self.ln0 = nn.LayerNorm(args.n_embd)
if args.my_pos_emb > 0:
self.pos_emb_x = nn.Parameter(
torch.zeros((1, args.my_pos_emb, args.n_embd))
)
self.pos_emb_y = nn.Parameter(
torch.zeros((args.my_pos_emb, 1, args.n_embd))
)
if self.layer_id == 0 and self.args.pre_ffn > 0:
self.ffnPre = RWKV_ChannelMix(args, 0)
else:
if "x060" in os.environ["RWKV_MY_TESTING"]:
if os.environ["RWKV_TRAIN_TYPE"] == "states":
self.att = RWKV_Tmix_x060_state(args, layer_id)
elif os.environ["RWKV_TRAIN_TYPE"] == "infctx":
self.att = RWKV_Tmix_x060_infctx(args, layer_id)
else:
self.att = RWKV_Tmix_x060(args, layer_id)
else:
self.att = RWKV_TimeMix_RWKV5(args, layer_id)
if "g" in os.environ["RWKV_MY_TESTING"]:
self.ffn = MishGLU(args, layer_id)
else:
if "x060" in os.environ["RWKV_MY_TESTING"]:
if os.environ["RWKV_TRAIN_TYPE"] == "infctx":
self.ffn = RWKV_CMix_x060_infctx(args, layer_id)
else:
self.ffn = RWKV_CMix_x060(args, layer_id)
else:
self.ffn = RWKV_ChannelMix(args, layer_id)
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
self.tiny_ln = nn.LayerNorm(args.n_embd)
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.register_buffer(
"tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
)
if args.dropout > 0:
self.drop0 = nn.Dropout(p=args.dropout)
self.drop1 = nn.Dropout(p=args.dropout)
if os.environ["RWKV_TRAIN_TYPE"] == "infctx":
def forward(self, x, last_state: BlockState, x_emb=None):
args = self.args
B, T, C = x.size()
if self.layer_id == 0:
x = self.ln0(x)
if args.my_pos_emb > 0:
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[
:-1, :
]
x = x + pos_emb
if self.args.dropout == 0:
if self.layer_id == 0 and args.pre_ffn > 0:
x = x + self.ffnPre(self.ln1(x))
else:
att_out, att_state = self.att(
self.ln1(x), last_state.time_mix_state
)
x = x + att_out
ffn_out, ffn_state = self.ffn(self.ln2(x), last_state.channel_mix_state)
x = x + ffn_out
else:
if self.layer_id == 0 and args.pre_ffn > 0:
x = self.drop0(x + self.ffnPre(self.ln1(x)))
else:
x = self.drop0(x + self.att(self.ln1(x)))
x = self.drop1(x + self.ffn(self.ln2(x)))
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
xx = self.tiny_ln(x)
q = self.tiny_q(xx)[:, :T, :]
k = self.tiny_k(xx)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
x = x + c @ self.tiny_v(x_emb)
return x, BlockState(att_state, ffn_state)
else:
def forward(self, x, x_emb=None):
args = self.args
B, T, C = x.size()
if self.layer_id == 0:
x = self.ln0(x)
if args.my_pos_emb > 0:
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[
:-1, :
]
x = x + pos_emb
if self.args.dropout == 0:
if self.layer_id == 0 and args.pre_ffn > 0:
x = x + self.ffnPre(self.ln1(x))
else:
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))
else:
if self.layer_id == 0 and args.pre_ffn > 0:
x = self.drop0(x + self.ffnPre(self.ln1(x)))
else:
x = self.drop0(x + self.att(self.ln1(x)))
x = self.drop1(x + self.ffn(self.ln2(x)))
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
xx = self.tiny_ln(x)
q = self.tiny_q(xx)[:, :T, :]
k = self.tiny_k(xx)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
x = x + c @ self.tiny_v(x_emb)
return x
if os.environ["RWKV_TRAIN_TYPE"] == "infctx":
class L2Wrap(torch.autograd.Function):
@staticmethod
def forward(ctx, loss, y, token_amount):
ctx.save_for_backward(y)
ctx.token_amount = token_amount
return loss
@staticmethod
def backward(
ctx, grad_output
):  # Does this break consistency between batch size and gradient accumulation? It seems so: with grad accumulation the factor grows, but only the loss is rescaled; the regularization term here is not.
y = ctx.saved_tensors[0]
# to encourage the logits to be close to 0
if ctx.token_amount == 0:
return (grad_output, None, None)
factor = 1e-4 / ctx.token_amount  # similar to how cross-entropy is averaged over tokens.
maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
if os.environ.get("WN_FIX_L2WRAP"):  # make the result batch-equivalent
# maxx[maxx<3.]=0.  # avoid pulling down logits that are already small; only pull those above the threshold down
gy.scatter_(-1, ids, maxx * factor * grad_output)
else:
gy.scatter_(-1, ids, maxx * factor)
return (grad_output, gy, None)
else:
class L2Wrap(torch.autograd.Function):
@staticmethod
def forward(ctx, loss, y):
ctx.save_for_backward(y)
return loss
@staticmethod
def backward(ctx, grad_output):
y = ctx.saved_tensors[0]
# to encourage the logits to be close to 0
factor = 1e-4 / (y.shape[0] * y.shape[1])
maxx, ids = torch.max(y, -1, keepdim=True)
gy = torch.zeros_like(y)
gy.scatter_(-1, ids, maxx * factor)
return (grad_output, gy)
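# L2Wrap passes the loss through unchanged but injects an extra gradient at each
# position's argmax logit in backward (1e-4 * max_logit, averaged over B*T tokens here, or
# over token_amount in the infctx variant). This nudges the largest logits toward zero so
# the head output does not drift to large magnitudes.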
class RWKV(pl.LightningModule):
def __init__(self, args):
super().__init__()
self.args = args
if not hasattr(args, "dim_att"):
args.dim_att = args.n_embd
if not hasattr(args, "dim_ffn"):
args.dim_ffn = args.n_embd * 4
if not hasattr(args, "tiny_att_layer"):
args.tiny_att_layer = -1
if not hasattr(args, "tiny_att_dim"):
args.tiny_att_dim = -1
assert args.n_embd % 32 == 0
assert args.dim_att % 32 == 0
assert args.dim_ffn % 32 == 0
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
self.ln_out = nn.LayerNorm(args.n_embd)
self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
if args.head_qk > 0:
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
self.register_buffer(
"copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
)
if args.dropout > 0:
self.drop0 = nn.Dropout(p=args.dropout)
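# configure_optimizers groups parameters by learning-rate scale: with layerwise_lr
# enabled, the low-rank _w1/_w2 tensors stay at 1x, time_mix/time_maa and time_faaaa get
# 1x or 2x depending on my_pile_stage, time_decay gets 2x or 3x, time_first gets 3x, and
# parameters with >= 2 effective dims go into the weight-decay group when weight_decay > 0.
# The "my_lr_scale" entries are presumably consumed by the trainer's LR schedule elsewhere
# in the repo.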
def configure_optimizers(self):
args = self.args
lr_decay = set()
lr_1x = set()
lr_2x = set()
lr_3x = set()
for n, p in self.named_parameters():
if not p.requires_grad:
continue
if (("_w1" in n) or ("_w2" in n)) and (args.layerwise_lr > 0):
lr_1x.add(n)
elif (("time_mix" in n) or ("time_maa" in n)) and (args.layerwise_lr > 0):
if args.my_pile_stage == 2:
lr_2x.add(n)
else:
lr_1x.add(n)
elif (("time_decay" in n) or ("time_daaaa" in n)) and (
args.layerwise_lr > 0
):
if args.my_pile_stage == 2:
lr_3x.add(n)
else:
lr_2x.add(n)
elif ("time_faaaa" in n) and (args.layerwise_lr > 0):
if args.my_pile_stage == 2:
lr_2x.add(n)
else:
lr_1x.add(n)
elif ("time_first" in n) and (args.layerwise_lr > 0):
lr_3x.add(n)
elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0):
lr_decay.add(n)
else:
lr_1x.add(n)
lr_decay = sorted(list(lr_decay))
lr_1x = sorted(list(lr_1x))
lr_2x = sorted(list(lr_2x))
lr_3x = sorted(list(lr_3x))
# print('decay', lr_decay)
# print('1x', lr_1x)
# print('2x', lr_2x)
# print('3x', lr_3x)
param_dict = {n: p for n, p in self.named_parameters()}
if args.layerwise_lr > 0:
if args.my_pile_stage == 2:
optim_groups = [
{
"params": [param_dict[n] for n in lr_1x],
"weight_decay": 0.0,
"my_lr_scale": 1.0,
},
{
"params": [param_dict[n] for n in lr_2x],
"weight_decay": 0.0,
"my_lr_scale": 5.0,
}, # test: 2e-3 / args.lr_init},
{
"params": [param_dict[n] for n in lr_3x],
"weight_decay": 0.0,
"my_lr_scale": 5.0,
}, # test: 3e-3 / args.lr_init},
]
else:
optim_groups = [
{
"params": [param_dict[n] for n in lr_1x],
"weight_decay": 0.0,
"my_lr_scale": 1.0,
},
{
"params": [param_dict[n] for n in lr_2x],
"weight_decay": 0.0,
"my_lr_scale": 2.0,
},
{
"params": [param_dict[n] for n in lr_3x],
"weight_decay": 0.0,
"my_lr_scale": 3.0,
},
]
else:
optim_groups = [
{
"params": [param_dict[n] for n in lr_1x],
"weight_decay": 0.0,
"my_lr_scale": 1.0,
}
]
if args.weight_decay > 0:
optim_groups += [
{
"params": [param_dict[n] for n in lr_decay],
"weight_decay": args.weight_decay,
"my_lr_scale": 1.0,
}
]
if self.deepspeed_offload:
return DeepSpeedCPUAdam(
optim_groups,
lr=self.args.lr_init,
betas=self.args.betas,
eps=self.args.adam_eps,
bias_correction=True,
adamw_mode=True,
amsgrad=False,
)
return FusedAdam(
optim_groups,
lr=self.args.lr_init,
betas=self.args.betas,
eps=self.args.adam_eps,
bias_correction=True,
adam_w_mode=True,
amsgrad=False,
)
else:
if self.deepspeed_offload:
return DeepSpeedCPUAdam(
optim_groups,
lr=self.args.lr_init,
betas=self.args.betas,
eps=self.args.adam_eps,
bias_correction=True,
adamw_mode=False,
weight_decay=0,
amsgrad=False,
)
return FusedAdam(
optim_groups,
lr=self.args.lr_init,
betas=self.args.betas,
eps=self.args.adam_eps,
bias_correction=True,
adam_w_mode=False,
weight_decay=0,
amsgrad=False,
)
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
@property
def deepspeed_offload(self) -> bool:
strategy = self.trainer.strategy
if isinstance(strategy, DeepSpeedStrategy):
cfg = strategy.config["zero_optimization"]
return cfg.get("offload_optimizer") or cfg.get("offload_param")
return False
if os.environ["RWKV_TRAIN_TYPE"] == "infctx":
def forward(
self, idx, last_shift_states: torch.Tensor, last_wkv_states: torch.Tensor
):
args = self.args
B, T = idx.size()
assert T <= args.chunk_ctx, "Cannot forward, model ctx_len is exhausted."
C = args.n_embd
H = args.dim_att // args.head_size_a
assert C == H * args.head_size_a
x = self.emb(idx)
x_emb = x
new_states = BlockStateList.empty(
args.n_layer, B, args.n_embd, H, x.device, x.dtype
)
if args.dropout > 0:
x = self.drop0(x)
for i, (block, block_state) in enumerate(
zip(self.blocks, BlockStateList(last_shift_states, last_wkv_states))
):
# x = x.to(block.device)
if args.grad_cp == 1 and i > 0: # and i < len(self.blocks)-1
x, new_block_state = torch_checkpoint(
block, x, block_state, use_reentrant=False
)
else:
x, new_block_state = block(x, block_state)
new_states[i] = new_block_state
x = self.ln_out(x)
if args.head_qk > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
c = c @ F.one_hot(idx, num_classes=args.vocab_size)
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
x = self.head(x) + c
else:
x = self.head(x)
return x, new_states.shift_states, new_states.wkv_states
def training_step(self, batch, batch_idx):
args = self.args
T_train = args.chunk_ctx
idx, targets = batch
B, T = idx.shape
C = args.n_embd
H = args.dim_att // args.head_size_a
assert C == H * args.head_size_a
states = BlockStateList.create(
args.n_layer, B, C, H, idx.device, self.emb.weight.dtype
)
def checkpointed_step(
idx,
targets,
prev_loss,
last_shift_states,
last_wkv_states,
prev_token_amount,
):
logits, new_shift_states, new_wkv_states = self(
idx, last_shift_states, last_wkv_states
)
current_token_amount = (targets != -100).sum()  # would this be more appropriate?
current_token_amount = idx.shape[1]
if current_token_amount == 0:
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.reshape(-1),
reduction="sum",
)
else:
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), targets.reshape(-1)
)
loss = L2Wrap.apply(loss, logits, current_token_amount)
new_token_amount = prev_token_amount + current_token_amount
if new_token_amount > 0:
new_loss = prev_loss * (
prev_token_amount / new_token_amount
) + loss * (current_token_amount / new_token_amount)
else:
new_loss = prev_loss
return new_loss, new_shift_states, new_wkv_states, new_token_amount
total_loss = torch.tensor(0.0, dtype=self.emb.weight.dtype).requires_grad_()
token_amount = 0
i = 0
for i in range(math.ceil(T / T_train)):
# states.shift_states = states.shift_states.cuda()
# states.wkv_states = states.wkv_states.cuda()
total_loss, new_shift_states, new_wkv_states, token_amount = (
torch_checkpoint(
checkpointed_step,
idx[:, i * T_train : (i + 1) * T_train],
targets[:, i * T_train : (i + 1) * T_train],
total_loss,
states.shift_states,
states.wkv_states,
token_amount,
use_reentrant=False,
)
)
# total_loss,new_shift_states, new_wkv_states,token_amount = checkpointed_step(
# idx[:, i * T_train:(i + 1) * T_train],
# targets[:, i * T_train:(i + 1) * T_train],
# total_loss,
# states.shift_states,
# states.wkv_states,
# token_amount
# )
# new_shift_states = new_shift_states.cpu()
# new_wkv_states = new_wkv_states.cpu()
states = BlockStateList(new_shift_states, new_wkv_states)
return total_loss
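# Infctx training: the sequence is split into chunks of args.chunk_ctx tokens; each chunk
# runs inside torch_checkpoint so activations are recomputed during backward, the
# shift/wkv states are carried from one chunk to the next, and the loss is accumulated as
# a token-weighted running average across chunks.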
else:
def forward(self, idx):
args = self.args
B, T = idx.size()
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
x = self.emb(idx)
x_emb = x
if args.dropout > 0:
x = self.drop0(x)
if args.tiny_att_dim > 0:
for block in self.blocks:
if args.grad_cp == 1:
if args.lora or args.state_tune or args.train_type == "state":
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
else:
x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
else:
x = block(x, x_emb)
else:
for block in self.blocks:
if args.grad_cp == 1:
if args.lora or args.state_tune or args.train_type == "state":
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
else:
x = deepspeed.checkpointing.checkpoint(block, x)
else:
x = block(x)
x = self.ln_out(x)
if args.head_qk > 0:
q = self.head_q(x)[:, :T, :]
k = self.head_k(x)[:, :T, :]
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
c = c @ F.one_hot(idx, num_classes=args.vocab_size)
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
x = self.head(x) + c
else:
x = self.head(x)
return x
def training_step(self, batch, batch_idx):
args = self.args
if args.my_qa_mask != 1:
idx, targets = batch
logits = self(idx)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), targets.view(-1)
)
# if '0' in os.environ["RWKV_MY_TESTING"]:
# print('logits', logits)
# torch.set_printoptions(threshold=10000)
# print('idx', idx)
# exit(0)
else:
idx, targets, mask = batch
mask = mask.view(-1)
sum_mask = torch.sum(mask).item()
# if sum_mask == 0:
# return torch.tensor([0.0], requires_grad=True)
logits = self(idx)
if sum_mask == mask.shape[0]:
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), targets.view(-1)
)
# print('rank', self.global_rank, 'loss', loss.item())
else:
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
reduction="none",
)
# loss_raw = loss
loss = torch.sum(loss * mask) / sum_mask
# torch.set_printoptions(threshold=10000)
# if True: #self.global_rank == 1:
# tmp = ''
# sss = 0
# ccc = 0
# for i in range(mask.shape[0]):
# if mask[i] > 0:
# tmp += str(idx.view(-1)[i].item()) + ','
# sss += loss_raw.view(-1)[i].float().item()
# ccc += 1
# print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
return L2Wrap.apply(loss, logits)
def training_step_end(self, batch_parts):
if pl.__version__[0] != "2":
all = self.all_gather(batch_parts)
if self.trainer.is_global_zero:
self.trainer.my_loss_all = all
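# generate_init_weight builds the initial state_dict from scratch: LayerNorm / time_* /
# mask / positional tensors keep their constructor values (ln_x.weight gets a per-layer
# scale), the embedding is initialized uniformly in [-lr_init, lr_init], most projection
# matrices are orthogonal with a shape-dependent gain, and output-side matrices
# (att.output, ffn.value, ffn.receptance, head_q, ...) start at zero.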
def generate_init_weight(self):
print(
f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
)
m = {}
for n in self.state_dict():
p = self.state_dict()[n]
shape = p.shape
gain = 1.0
scale = 1.0
if (
"ln_" in n
or ".ln" in n
or "time_" in n
or "_mask" in n
or "pos_emb" in n
or ".mask." in n
):
if "ln_x.weight" in n:
layer_scale = (1 + int(n.split(".")[1])) / self.args.n_layer
m[n] = (p * 0.0) + (layer_scale**0.7)
else:
m[n] = p
else:
if n == "emb.weight":
scale = -1 * self.args.lr_init
else:
if shape[0] > shape[1]:
gain = math.sqrt(shape[0] / shape[1])
zero = [
".att.output.",
".ffn.value.",
".ffn.receptance.",
".ffnPre.value.",
".ffnPre.receptance.",
"head_q.",
".oo.",
".rr.",
]
for kk in zero:
if kk in n:
scale = 0
if n == "head.weight":
scale = 0.5
if "head_k." in n:
scale = 0.1
if "head_q." in n:
scale = 0
print(
f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}"
)
if self.args.accelerator.upper() == "GPU":
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
else:
m[n] = torch.empty((shape[0], shape[1]))
if scale == 0:
nn.init.zeros_(m[n])
elif scale < 0:
nn.init.uniform_(m[n], a=scale, b=-scale)
else:
nn.init.orthogonal_(m[n], gain=gain * scale)
m[n] = m[n].cpu()
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
m[n] = m[n].half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
m[n] = m[n].bfloat16()
# if n == "emb.weight":
# print(m[n])
gc.collect()
torch.cuda.empty_cache()
return m