########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import functools
from torch.utils.checkpoint import checkpoint as torch_checkpoint

import os, math, gc, importlib
import torch

# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy

if importlib.util.find_spec("deepspeed"):
    import deepspeed
    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam


LORA_CONFIG = {
    "r": 0,
    "alpha": 0,
    "dropout": 0,
    "parts": {"att", "ln", "time", "ffn"},
}
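
# LORA_CONFIG is a module-level dict meant to be populated by the caller (e.g. the
# accompanying training script) before any model is constructed; with r == 0 the
# factories below fall back to plain nn.Linear, i.e. LoRA is disabled. "att" and "ffn"
# gate which projections become LoraLinear; "ln" and "time" appear to be consumed by the
# training script when deciding which parameters stay trainable. Illustrative values
# only (not defaults of this file):
#   LORA_CONFIG["r"] = 8
#   LORA_CONFIG["alpha"] = 16
#   LORA_CONFIG["dropout"] = 0.01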


class LoraLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool):
        super().__init__()

        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        assert bias == False, "Biased LoraLinear not supported"

        r, alpha, dropout = (
            LORA_CONFIG["r"],
            LORA_CONFIG["alpha"],
            LORA_CONFIG["dropout"],
        )
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.empty(out_features, r))
        self.lora_dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        return F.linear(x, self.weight) + self.scaling * F.linear(
            F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
        )
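
    # The forward pass is the frozen base projection plus a scaled low-rank update:
    #   y = x @ W.T + (alpha / r) * (dropout(x) @ A.T) @ B.T
    # Because lora_B is initialized to zero, the module starts out numerically identical
    # to a plain nn.Linear with the same weight.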


@functools.wraps(LoraLinear)
def make_linear_att(*args, **kwargs):
    if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
        return LoraLinear(*args, **kwargs)
    else:
        return nn.Linear(*args, **kwargs)


@functools.wraps(LoraLinear)
def make_linear_ffn(*args, **kwargs):
    if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
        return LoraLinear(*args, **kwargs)
    else:
        return nn.Linear(*args, **kwargs)
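
# These two factories let the attention (time-mix) and FFN (channel-mix) blocks decide,
# per call site and purely from LORA_CONFIG at construction time, whether to build a
# LoraLinear or a plain nn.Linear. Changing LORA_CONFIG only affects modules created
# afterwards.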


try:
    print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"])
except:
    os.environ["RWKV_MY_TESTING"] = ""


def __nop(ob):
    return ob


MyModule = nn.Module
MyFunction = __nop
if os.environ["RWKV_JIT_ON"] == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method
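
# With RWKV_JIT_ON=1 the mixing modules compile under TorchScript (ScriptModule /
# script_method); otherwise MyModule / MyFunction are plain no-op aliases and everything
# runs in eager mode.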


########################################################################################################
# CUDA Kernel
########################################################################################################

from torch.utils.cpp_extension import load

HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"])

if "x060" in os.environ["RWKV_MY_TESTING"]:
    wkv6_cuda = load(
        name="wkv6",
        sources=[
            "finetune/lora/v6/cuda/wkv6_op.cpp",
            f"finetune/lora/v6/cuda/wkv6_cuda.cu",
        ],
        verbose=True,
        extra_cuda_cflags=[
            "-res-usage",
            "--use_fast_math",
            "-O3",
            "-Xptxas -O3",
            "--extra-device-vectorization",
            f"-D_N_={HEAD_SIZE}",
            f"-D_T_={int(os.environ['RWKV_CTXLEN'])}",
        ],
    )

    class WKV_6(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, H, r, k, v, w, u):
            with torch.no_grad():
                assert r.dtype == torch.bfloat16
                assert k.dtype == torch.bfloat16
                assert v.dtype == torch.bfloat16
                assert w.dtype == torch.bfloat16
                assert u.dtype == torch.bfloat16
                assert HEAD_SIZE == C // H
                ctx.B = B
                ctx.T = T
                ctx.C = C
                ctx.H = H
                assert r.is_contiguous()
                assert k.is_contiguous()
                assert v.is_contiguous()
                assert w.is_contiguous()
                assert u.is_contiguous()
                ew = (-torch.exp(w.float())).contiguous()
                ctx.save_for_backward(r, k, v, ew, u)
                y = torch.empty(
                    (B, T, C),
                    device=r.device,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-100, 100)
                wkv6_cuda.forward(B, T, C, H, r, k, v, ew, u, y)
                return y

        @staticmethod
        def backward(ctx, gy):
            with torch.no_grad():
                assert gy.dtype == torch.bfloat16
                B = ctx.B
                T = ctx.T
                C = ctx.C
                H = ctx.H
                assert gy.is_contiguous()
                r, k, v, ew, u = ctx.saved_tensors
                gr = torch.empty(
                    (B, T, C),
                    device=gy.device,
                    requires_grad=False,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-100, 100)
                gk = torch.empty(
                    (B, T, C),
                    device=gy.device,
                    requires_grad=False,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-100, 100)
                gv = torch.empty(
                    (B, T, C),
                    device=gy.device,
                    requires_grad=False,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-100, 100)
                gw = torch.empty(
                    (B, T, C),
                    device=gy.device,
                    requires_grad=False,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-100, 100)
                gu = torch.empty(
                    (B, C),
                    device=gy.device,
                    requires_grad=False,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-100, 100)
                wkv6_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu)
                gu = torch.sum(gu, 0).view(H, C // H)
                return (None, None, None, None, gr, gk, gv, gw, gu)

    def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
        return WKV_6.apply(B, T, C, H, r, k, v, w, u)
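
    # RUN_CUDA_RWKV6 is a thin wrapper around the custom autograd Function above.
    # The kernel expects contiguous bfloat16 tensors: r, k, v, w of shape (B, T, C) and
    # u of shape (H, C // H). The decay w is given in log space; forward converts it to
    # -exp(w) in float32 before invoking the compiled kernel.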

else:
    wkv5_cuda = load(
        name="wkv5",
        sources=[
            "finetune/lora/v6/cuda/wkv5_op.cpp",
            f"finetune/lora/v6/cuda/wkv5_cuda.cu",
        ],
        verbose=True,
        extra_cuda_cflags=[
            "-res-usage",
            "--use_fast_math",
            "-O3",
            "-Xptxas -O3",
            "--extra-device-vectorization",
            f"-D_N_={HEAD_SIZE}",
        ],
    )

    class WKV_5(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, H, r, k, v, w, u):
            with torch.no_grad():
                assert r.dtype == torch.bfloat16
                assert k.dtype == torch.bfloat16
                assert v.dtype == torch.bfloat16
                assert w.dtype == torch.bfloat16
                assert u.dtype == torch.bfloat16
                assert HEAD_SIZE == C // H
                ctx.B = B
                ctx.T = T
                ctx.C = C
                ctx.H = H
                assert r.is_contiguous()
                assert k.is_contiguous()
                assert v.is_contiguous()
                assert w.is_contiguous()
                assert u.is_contiguous()
                ew = (-torch.exp(w.float())).contiguous()
                eew = (torch.exp(ew)).contiguous()
                ctx.save_for_backward(r, k, v, eew, ew, u)
                y = torch.empty(
                    (B, T, C),
                    device=r.device,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-1, 1)
                wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
                return y

        @staticmethod
        def backward(ctx, gy):
            with torch.no_grad():
                assert gy.dtype == torch.bfloat16
                B = ctx.B
                T = ctx.T
                C = ctx.C
                H = ctx.H
                assert gy.is_contiguous()
                r, k, v, eew, ew, u = ctx.saved_tensors
                gr = torch.empty(
                    (B, T, C),
                    device=gy.device,
                    requires_grad=False,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-1, 1)
                gk = torch.empty(
                    (B, T, C),
                    device=gy.device,
                    requires_grad=False,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-1, 1)
                gv = torch.empty(
                    (B, T, C),
                    device=gy.device,
                    requires_grad=False,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-1, 1)
                gw = torch.empty(
                    (B, C),
                    device=gy.device,
                    requires_grad=False,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-1, 1)
                gu = torch.empty(
                    (B, C),
                    device=gy.device,
                    requires_grad=False,
                    dtype=torch.bfloat16,
                    memory_format=torch.contiguous_format,
                ) # .uniform_(-1, 1)
                wkv5_cuda.backward(
                    B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu
                )
                gw = torch.sum(gw, 0).view(H, C // H)
                gu = torch.sum(gu, 0).view(H, C // H)
                return (None, None, None, None, gr, gk, gv, gw, gu)

    def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
        return WKV_5.apply(B, T, C, H, r, k, v, w, u)
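
    # Same wrapper pattern as the v6 path. The v5 kernel receives the decay both as
    # ew = -exp(w) and as eew = exp(-exp(w)), and because w is a per-head parameter here
    # (not data-dependent), its gradient gw is reduced over the batch dimension just
    # like gu.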


########################################################################################################


class RWKV_TimeMix_RWKV5(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.head_size = args.head_size_a
        assert HEAD_SIZE == self.head_size # change HEAD_SIZE to match args.head_size_a
        self.n_head = args.dim_att // self.head_size
        assert args.dim_att % self.n_head == 0
        self.head_size_divisor = args.head_size_divisor

        with torch.no_grad():
            ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd

            # fancy time_mix
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(
                torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
            )
            self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
            self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))

            # fancy time_decay
            decay_speed = torch.ones(args.dim_att)
            for n in range(args.dim_att):
                decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
                    0.7 + 1.3 * ratio_0_to_1
                )
            self.time_decay = nn.Parameter(
                decay_speed.reshape(self.n_head, self.head_size)
            )
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            tmp = torch.zeros(args.dim_att)
            for n in range(args.dim_att):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag

            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)

        self.output = make_linear_att(args.dim_att, args.n_embd, bias=False)
        self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)

    @MyFunction
    def jit_func(self, x):
        B, T, C = x.size()

        xx = self.time_shift(
            x
        ) # Mix x with the previous timestep to produce xk, xv, xr
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)

        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))

        return r, k, v, g

    @MyFunction
    def jit_func_2(self, x, g):
        B, T, C = x.size()
        x = x.view(B * T, C)

        x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
        x = self.output(x * g)
        return x

    def forward(self, x):
        B, T, C = x.size()
        H = self.n_head

        r, k, v, g = self.jit_func(x)

        x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)

        return self.jit_func_2(x, g)
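
    # Summary: token-shift interpolation (time_mix_*) builds r/k/v/g from the current and
    # previous token, the WKV5 kernel applies the per-head decayed linear attention with
    # learned decay (time_decay) and bonus (time_faaaa), and the result is group-normed,
    # gated by silu(gate), and projected back to n_embd.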


class RWKV_Tmix_x060(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.head_size = args.head_size_a
        self.n_head = args.dim_att // self.head_size
        assert args.dim_att % self.n_head == 0

        with torch.no_grad():
            ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd

            # fancy time_mix
            self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_v = nn.Parameter(
                1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            )
            self.time_maa_r = nn.Parameter(
                1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
            )
            self.time_maa_g = nn.Parameter(
                1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
            )

            TIME_MIX_EXTRA_DIM = 32 # generate TIME_MIX for w,k,v,r,g
            self.time_maa_w1 = nn.Parameter(
                torch.zeros(args.n_embd, TIME_MIX_EXTRA_DIM * 5).uniform_(-1e-4, 1e-4)
            )
            self.time_maa_w2 = nn.Parameter(
                torch.zeros(5, TIME_MIX_EXTRA_DIM, args.n_embd).uniform_(-1e-4, 1e-4)
            )

            # fancy time_decay
            decay_speed = torch.ones(args.dim_att)
            for n in range(args.dim_att):
                decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
                    0.7 + 1.3 * ratio_0_to_1
                )
            self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))

            TIME_DECAY_EXTRA_DIM = 64
            self.time_decay_w1 = nn.Parameter(
                torch.zeros(args.n_embd, TIME_DECAY_EXTRA_DIM).uniform_(-1e-4, 1e-4)
            )
            self.time_decay_w2 = nn.Parameter(
                torch.zeros(TIME_DECAY_EXTRA_DIM, args.dim_att).uniform_(-1e-4, 1e-4)
            )

            tmp = torch.zeros(args.dim_att)
            for n in range(args.dim_att):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag

            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.output = make_linear_att(args.dim_att, args.n_embd, bias=False)
        self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.ln_x = nn.GroupNorm(
            self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
        )

    @MyFunction
    def jit_func(self, x):
        B, T, C = x.size()

        xx = self.time_shift(x) - x

        xxx = x + xx * self.time_maa_x
        xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        xw = x + xx * (self.time_maa_w + mw)
        xk = x + xx * (self.time_maa_k + mk)
        xv = x + xx * (self.time_maa_v + mv)
        xr = x + xx * (self.time_maa_r + mr)
        xg = x + xx * (self.time_maa_g + mg)

        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))

        ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
        w = self.time_decay + ww

        return r, k, v, g, w

    @MyFunction
    def jit_func_2(self, x, g):
        B, T, C = x.size()
        x = x.view(B * T, C)

        x = self.ln_x(x).view(B, T, C)
        x = self.output(x * g)
        return x

    def forward(self, x):
        B, T, C = x.size()
        H = self.n_head

        r, k, v, g, w = self.jit_func(x)
        x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)

        return self.jit_func_2(x, g)
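
    # x060 differences from the v5 block: the token-shift amounts are data-dependent
    # (the small bottleneck time_maa_w1/w2 produces per-token offsets mw/mk/mv/mr/mg),
    # and the decay w is likewise per token, computed through the low-rank
    # time_decay_w1/w2 pair and added to the static time_decay before the v6 kernel call.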


########################################################################################################


class RWKV_ChannelMix(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad(): # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))

        self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
        self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        k = torch.relu(k) ** 2
        kv = self.value(k)
        return torch.sigmoid(self.receptance(xr)) * kv
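
    # Channel mix is a squared-ReLU two-layer FFN (key -> relu^2 -> value) whose output
    # is gated elementwise by a sigmoid "receptance" computed from the token-shifted
    # input.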


class RWKV_CMix_x060(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad(): # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))

        self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
        self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x) - x
        xk = x + xx * self.time_maa_k
        xr = x + xx * self.time_maa_r

        k = self.key(xk)
        k = torch.relu(k) ** 2
        kv = self.value(k)
        return torch.sigmoid(self.receptance(xr)) * kv
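
    # Same structure as RWKV_ChannelMix; only the token-shift parameterization changes
    # (delta form: x + (shift(x) - x) * time_maa_*), matching the x060 time-mix block.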


########################################################################################################


class MishGLU(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)

            x = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                x[0, 0, i] = i / args.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        a = self.aa(xa)
        b = self.bb(xb)
        return self.value(a * F.mish(b))


########################################################################################################
# The RWKV Model with our blocks
########################################################################################################


class Block(nn.Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(args.n_embd)
            if args.my_pos_emb > 0:
                self.pos_emb_x = nn.Parameter(
                    torch.zeros((1, args.my_pos_emb, args.n_embd))
                )
                self.pos_emb_y = nn.Parameter(
                    torch.zeros((args.my_pos_emb, 1, args.n_embd))
                )

        if self.layer_id == 0 and self.args.pre_ffn > 0:
            self.ffnPre = RWKV_ChannelMix(args, 0)
        else:
            if "x060" in os.environ["RWKV_MY_TESTING"]:
                self.att = RWKV_Tmix_x060(args, layer_id)
            else:
                self.att = RWKV_TimeMix_RWKV5(args, layer_id)

        if "g" in os.environ["RWKV_MY_TESTING"]:
            self.ffn = MishGLU(args, layer_id)
        else:
            if "x060" in os.environ["RWKV_MY_TESTING"]:
                self.ffn = RWKV_CMix_x060(args, layer_id)
            else:
                self.ffn = RWKV_ChannelMix(args, layer_id)

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            self.tiny_ln = nn.LayerNorm(args.n_embd)
            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
            self.register_buffer(
                "tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
            )

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

    def forward(self, x, x_emb=None):
        args = self.args
        B, T, C = x.size()
        if self.layer_id == 0:
            x = self.ln0(x)
            if args.my_pos_emb > 0:
                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :]
                x = x + pos_emb

        if self.args.dropout == 0:
            if self.layer_id == 0 and args.pre_ffn > 0:
                x = x + self.ffnPre(self.ln1(x))
            else:
                x = x + self.att(self.ln1(x))
            x = x + self.ffn(self.ln2(x))
        else:
            if self.layer_id == 0 and args.pre_ffn > 0:
                x = self.drop0(x + self.ffnPre(self.ln1(x)))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            x = self.drop1(x + self.ffn(self.ln2(x)))

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            xx = self.tiny_ln(x)
            q = self.tiny_q(xx)[:, :T, :]
            k = self.tiny_k(xx)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
            x = x + c @ self.tiny_v(x_emb)
        return x
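
    # Each block is pre-LN residual: x + att(ln1(x)) then x + ffn(ln2(x)); layer 0 may use
    # an ffnPre instead of attention and applies ln0 to the raw embedding. The optional
    # "tiny attention" branch adds a small causal-attention correction that reads its
    # values from the original embedding x_emb.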


class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)
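
    # L2Wrap passes the loss through unchanged but, in backward, adds a tiny extra
    # gradient on the logits: for every position, only the single largest logit receives
    # grad = max_logit * 1e-4 / (B * T), which nudges logit magnitudes toward zero.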


class RWKV(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        if not hasattr(args, "dim_att"):
            args.dim_att = args.n_embd
        if not hasattr(args, "dim_ffn"):
            args.dim_ffn = args.n_embd * 4
        if not hasattr(args, "tiny_att_layer"):
            args.tiny_att_layer = -1
        if not hasattr(args, "tiny_att_dim"):
            args.tiny_att_dim = -1
        assert args.n_embd % 32 == 0
        assert args.dim_att % 32 == 0
        assert args.dim_ffn % 32 == 0

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        if args.head_qk > 0:
            self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.register_buffer(
                "copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
            )
        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)

    def configure_optimizers(self):
        args = self.args

        lr_decay = set()
        lr_1x = set()
        lr_2x = set()
        lr_3x = set()
        for n, p in self.named_parameters():
            if (("_w1" in n) or ("_w2" in n)) and (args.layerwise_lr > 0):
                lr_1x.add(n)
            elif (("time_mix" in n) or ("time_maa" in n)) and (args.layerwise_lr > 0):
                if args.my_pile_stage == 2:
                    lr_2x.add(n)
                else:
                    lr_1x.add(n)
            elif (("time_decay" in n) or ("time_daaaa" in n)) and (
                args.layerwise_lr > 0
            ):
                if args.my_pile_stage == 2:
                    lr_3x.add(n)
                else:
                    lr_2x.add(n)
            elif ("time_faaaa" in n) and (args.layerwise_lr > 0):
                if args.my_pile_stage == 2:
                    lr_2x.add(n)
                else:
                    lr_1x.add(n)
            elif ("time_first" in n) and (args.layerwise_lr > 0):
                lr_3x.add(n)
            elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0):
                lr_decay.add(n)
            else:
                lr_1x.add(n)

        lr_decay = sorted(list(lr_decay))
        lr_1x = sorted(list(lr_1x))
        lr_2x = sorted(list(lr_2x))
        lr_3x = sorted(list(lr_3x))
        # print('decay', lr_decay)
        # print('1x', lr_1x)
        # print('2x', lr_2x)
        # print('3x', lr_3x)
        param_dict = {n: p for n, p in self.named_parameters()}

        if args.layerwise_lr > 0:
            if args.my_pile_stage == 2:
                optim_groups = [
                    {
                        "params": [param_dict[n] for n in lr_1x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 1.0,
                    },
                    {
                        "params": [param_dict[n] for n in lr_2x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 5.0,
                    }, # test: 2e-3 / args.lr_init},
                    {
                        "params": [param_dict[n] for n in lr_3x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 5.0,
                    }, # test: 3e-3 / args.lr_init},
                ]
            else:
                optim_groups = [
                    {
                        "params": [param_dict[n] for n in lr_1x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 1.0,
                    },
                    {
                        "params": [param_dict[n] for n in lr_2x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 2.0,
                    },
                    {
                        "params": [param_dict[n] for n in lr_3x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 3.0,
                    },
                ]
        else:
            optim_groups = [
                {
                    "params": [param_dict[n] for n in lr_1x],
                    "weight_decay": 0.0,
                    "my_lr_scale": 1.0,
                }
            ]

        if args.weight_decay > 0:
            optim_groups += [
                {
                    "params": [param_dict[n] for n in lr_decay],
                    "weight_decay": args.weight_decay,
                    "my_lr_scale": 1.0,
                }
            ]
            if self.deepspeed_offload:
                return DeepSpeedCPUAdam(
                    optim_groups,
                    lr=self.args.lr_init,
                    betas=self.args.betas,
                    eps=self.args.adam_eps,
                    bias_correction=True,
                    adamw_mode=True,
                    amsgrad=False,
                )
            return FusedAdam(
                optim_groups,
                lr=self.args.lr_init,
                betas=self.args.betas,
                eps=self.args.adam_eps,
                bias_correction=True,
                adam_w_mode=True,
                amsgrad=False,
            )
        else:
            if self.deepspeed_offload:
                return DeepSpeedCPUAdam(
                    optim_groups,
                    lr=self.args.lr_init,
                    betas=self.args.betas,
                    eps=self.args.adam_eps,
                    bias_correction=True,
                    adamw_mode=False,
                    weight_decay=0,
                    amsgrad=False,
                )
            return FusedAdam(
                optim_groups,
                lr=self.args.lr_init,
                betas=self.args.betas,
                eps=self.args.adam_eps,
                bias_correction=True,
                adam_w_mode=False,
                weight_decay=0,
                amsgrad=False,
            )
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
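
    # Parameters are bucketed into lr_1x / lr_2x / lr_3x groups (plus an optional
    # weight-decay group for >=2-D weights); "my_lr_scale" is not an optimizer argument
    # but a per-group hint, presumably consumed by the training loop's LR schedule to
    # scale lr_init. DeepSpeedCPUAdam is chosen when ZeRO offload is active, FusedAdam
    # otherwise.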

    @property
    def deepspeed_offload(self) -> bool:
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            cfg = strategy.config["zero_optimization"]
            return cfg.get("offload_optimizer") or cfg.get("offload_param")
        return False

    def forward(self, idx):
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)
        x_emb = x

        if args.dropout > 0:
            x = self.drop0(x)
        if args.tiny_att_dim > 0:
            for block in self.blocks:
                if args.grad_cp == 1:
                    if args.lora:
                        x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
                    else:
                        x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
                else:
                    x = block(x, x_emb)
        else:
            for block in self.blocks:
                if args.grad_cp == 1:
                    if args.lora:
                        x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
                    else:
                        x = deepspeed.checkpointing.checkpoint(block, x)
                else:
                    x = block(x)

        x = self.ln_out(x)

        if args.head_qk > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x
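
    # Gradient checkpointing: with grad_cp == 1, LoRA runs go through
    # torch.utils.checkpoint with use_reentrant=False (presumably because the
    # non-reentrant path plays better with mostly-frozen weights), while full fine-tuning
    # uses deepspeed.checkpointing. Note that in the non-tiny-attention branch
    # torch_checkpoint still receives x_emb, which Block.forward accepts but does not use
    # there.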

    def training_step(self, batch, batch_idx):
        args = self.args
        if args.my_qa_mask != 1:
            idx, targets = batch
            logits = self(idx)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            # if '0' in os.environ["RWKV_MY_TESTING"]:
            #     print('logits', logits)
            #     torch.set_printoptions(threshold=10000)
            #     print('idx', idx)
            #     exit(0)
        else:
            idx, targets, mask = batch
            mask = mask.view(-1)
            sum_mask = torch.sum(mask).item()
            # if sum_mask == 0:
            #     return torch.tensor([0.0], requires_grad=True)

            logits = self(idx)
            if sum_mask == mask.shape[0]:
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)), targets.view(-1)
                )
                # print('rank', self.global_rank, 'loss', loss.item())
            else:
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
                )
                # loss_raw = loss
                loss = torch.sum(loss * mask) / sum_mask

                # torch.set_printoptions(threshold=10000)
                # if True: #self.global_rank == 1:
                #     tmp = ''
                #     sss = 0
                #     ccc = 0
                #     for i in range(mask.shape[0]):
                #         if mask[i] > 0:
                #             tmp += str(idx.view(-1)[i].item()) + ','
                #             sss += loss_raw.view(-1)[i].float().item()
                #             ccc += 1
                #     print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)

        return L2Wrap.apply(loss, logits)
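
    # When my_qa_mask == 1 the batch carries a per-token mask and the cross-entropy is
    # averaged only over masked-in positions; either way the loss is routed through
    # L2Wrap so the logit-magnitude penalty above is applied during backward.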

    def training_step_end(self, batch_parts):
        if pl.__version__[0] != "2":
            all = self.all_gather(batch_parts)
            if self.trainer.is_global_zero:
                self.trainer.my_loss_all = all

    def generate_init_weight(self):
        print(
            f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        m = {}
        for n in self.state_dict():
            p = self.state_dict()[n]
            shape = p.shape

            gain = 1.0
            scale = 1.0
            if (
                "ln_" in n
                or ".ln" in n
                or "time_" in n
                or "_mask" in n
                or "pos_emb" in n
                or ".mask." in n
            ):
                if "ln_x.weight" in n:
                    layer_scale = (1 + int(n.split(".")[1])) / self.args.n_layer
                    m[n] = (p * 0.0) + (layer_scale**0.7)
                else:
                    m[n] = p
            else:
                if n == "emb.weight":
                    scale = -1 * self.args.lr_init
                else:
                    if shape[0] > shape[1]:
                        gain = math.sqrt(shape[0] / shape[1])

                    zero = [
                        ".att.output.",
                        ".ffn.value.",
                        ".ffn.receptance.",
                        ".ffnPre.value.",
                        ".ffnPre.receptance.",
                        "head_q.",
                        ".oo.",
                        ".rr.",
                    ]

                    for kk in zero:
                        if kk in n:
                            scale = 0
                    if n == "head.weight":
                        scale = 0.5
                    if "head_k." in n:
                        scale = 0.1
                    if "head_q." in n:
                        scale = 0

                print(
                    f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}"
                )

                if self.args.accelerator.upper() == "GPU":
                    m[n] = torch.empty((shape[0], shape[1]), device="cuda")
                else:
                    m[n] = torch.empty((shape[0], shape[1]))

                if scale == 0:
                    nn.init.zeros_(m[n])
                elif scale < 0:
                    nn.init.uniform_(m[n], a=scale, b=-scale)
                else:
                    nn.init.orthogonal_(m[n], gain=gain * scale)

            m[n] = m[n].cpu()
            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
                m[n] = m[n].half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                m[n] = m[n].bfloat16()

            # if n == "emb.weight":
            #     print(m[n])

        gc.collect()
        torch.cuda.empty_cache()
        return m
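
    # Init scheme: LayerNorm / time_* / mask / positional tensors keep their constructed
    # values (ln_x.weight gets a depth-dependent scale), emb.weight is uniform in
    # [-lr_init, lr_init], the matrices listed in `zero` start at zero, and the remaining
    # matrices use orthogonal init scaled by gain * scale, then everything is cast to the
    # requested RWKV_FLOAT_MODE.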