########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import functools
from torch.utils.checkpoint import checkpoint as torch_checkpoint

import os, math, gc, importlib
import torch

# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy

if importlib.util.find_spec("deepspeed"):
    import deepspeed
    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from torch._lowrank import svd_lowrank
import bitsandbytes as bnb
from .infctx_module import *
from einops import rearrange
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6

LORA_CONFIG = {
    "r": 0,
    "alpha": 0,
    "dropout": 0,
    "parts": {"att", "ln", "time", "ffn"},
}
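
# The values above are placeholders; the training entry point is expected to overwrite
# LORA_CONFIG before the model is built, for example (illustrative values only):
#
#     LORA_CONFIG["r"] = 64
#     LORA_CONFIG["alpha"] = 128
#     LORA_CONFIG["dropout"] = 0.01
#     LORA_CONFIG["parts"] = {"att", "ffn", "time", "ln"}
#
# With r > 0, the make_linear_att()/make_linear_ffn() factories below return LoraLinear
# instead of nn.Linear, so each projection computes
#     y = W x + (alpha / r) * B (A dropout(x))
# instead of plain y = W x (PiSSA-initialized layers skip the scaling and dropout).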


class LoraLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int, bias: bool):
        super().__init__()

        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        assert bias == False, "Biased LoraLinear not supported"

        r, alpha, dropout = (
            LORA_CONFIG["r"],
            LORA_CONFIG["alpha"],
            LORA_CONFIG["dropout"],
        )
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.empty(out_features, r))
        self.lora_dropout = nn.Dropout(dropout)
        self.scaling = alpha / r
        self.r = r
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)
        self.pissa = False
        self.is_quant = False

    def pissa_init(self, svd_niter):

        self.pissa = True
        Ur, Sr, Vr = svd_lowrank(self.weight.data, self.r, niter=svd_niter)
        Vhr = Vr.t()
        lora_A = torch.diag(torch.sqrt(Sr)) @ Vhr
        lora_B = Ur @ torch.diag(torch.sqrt(Sr))
        self.lora_A.data = lora_A
        self.lora_B.data = lora_B
        self.weight.data = self.weight.data - lora_B @ lora_A

    def quant(self, quant_type):
        self.is_quant = True
        self.quant_type = quant_type
        if self.quant_type == "4bit":
            self.weight.data, self.qstate = bnb.functional.quantize_4bit(
                (self.weight.data).to("cuda")
            )
        elif self.quant_type == "nf4":
            self.weight.data, self.qstate = bnb.functional.quantize_nf4(
                (self.weight.data).to("cuda")
            )
        elif self.quant_type == "fp4":
            self.weight.data, self.qstate = bnb.functional.quantize_fp4(
                (self.weight.data).to("cuda")
            )

    def forward(self, x):

        if self.is_quant:
            # All three bnb 4-bit formats share the same LoRA arithmetic and only
            # differ in the dequantize kernel, so dispatch on the format once.
            dequantize = {
                "4bit": bnb.functional.dequantize_4bit,
                "nf4": bnb.functional.dequantize_nf4,
                "fp4": bnb.functional.dequantize_fp4,
            }.get(self.quant_type)
            if dequantize is not None:
                w = dequantize(self.weight.data, quant_state=self.qstate).to(
                    torch.bfloat16
                )
                if self.pissa:
                    return F.linear(x, w) + F.linear(
                        F.linear(x, self.lora_A), self.lora_B
                    )
                return F.linear(x, w) + self.scaling * F.linear(
                    F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
                )

        if self.pissa:
            return F.linear(x, self.weight) + F.linear(
                F.linear(x, self.lora_A), self.lora_B
            )
        return F.linear(x, self.weight) + self.scaling * F.linear(
            F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
        )
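
# A sketch of how these adapters are typically consumed (assumption, not something this
# file does): after training, the LoRA delta can be folded back into the base weight so
# the layer behaves like a plain nn.Linear, e.g.
#
#     with torch.no_grad():
#         if layer.pissa:
#             layer.weight += layer.lora_B @ layer.lora_A
#         else:
#             layer.weight += layer.scaling * (layer.lora_B @ layer.lora_A)
#
# For PiSSA the principal components were already subtracted from the weight in
# pissa_init(), which is why its forward pass omits the alpha / r scaling and dropout.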


@functools.wraps(LoraLinear)
def make_linear_att(*args, **kwargs):
    if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
        return LoraLinear(*args, **kwargs)
    else:
        return nn.Linear(*args, **kwargs)


@functools.wraps(LoraLinear)
def make_linear_ffn(*args, **kwargs):
    if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
        return LoraLinear(*args, **kwargs)
    else:
        return nn.Linear(*args, **kwargs)


try:
    print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"])
except KeyError:
    os.environ["RWKV_MY_TESTING"] = ""


def __nop(ob):
    return ob


MyModule = nn.Module
MyFunction = __nop
if os.environ["RWKV_JIT_ON"] == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method
|
||
|
||
|
||
########################################################################################################
|
||
# CUDA Kernel
|
||
########################################################################################################
|
||
if os.environ["WKV"] == "fla":
|
||
if "x060" in os.environ["RWKV_MY_TESTING"]:
|
||
if (
|
||
os.environ["RWKV_TRAIN_TYPE"] == "infctx"
|
||
and "x060" in os.environ["RWKV_MY_TESTING"]
|
||
):
|
||
|
||
def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s):
|
||
r = rearrange(r, "b l (h d) -> b h l d", h=H)
|
||
k = rearrange(k, "b l (h d) -> b h l d", h=H)
|
||
v = rearrange(v, "b l (h d) -> b h l d", h=H)
|
||
w = rearrange(-torch.exp(w), "b l (h d) -> b h l d", h=H)
|
||
o, state = chunk_rwkv6(
|
||
r, k, v, w, u=u, initial_state=s, output_final_state=True
|
||
)
|
||
x = rearrange(o, "b h l d -> b l (h d)")
|
||
return x, state
|
||
|
||
elif os.environ["RWKV_TRAIN_TYPE"] == "states":
|
||
|
||
def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s):
|
||
r = rearrange(r, "b l (h d) -> b h l d", h=H)
|
||
k = rearrange(k, "b l (h d) -> b h l d", h=H)
|
||
v = rearrange(v, "b l (h d) -> b h l d", h=H)
|
||
w = rearrange(-torch.exp(w), "b l (h d) -> b h l d", h=H)
|
||
o, _ = chunk_rwkv6(
|
||
r, k, v, w, u=u, initial_state=s, output_final_state=False
|
||
)
|
||
x = rearrange(o, "b h l d -> b l (h d)")
|
||
return x
|
||
|
||
else:
|
||
|
||
def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
|
||
r = rearrange(r, "b l (h d) -> b h l d", h=H)
|
||
k = rearrange(k, "b l (h d) -> b h l d", h=H)
|
||
v = rearrange(v, "b l (h d) -> b h l d", h=H)
|
||
w = rearrange(-torch.exp(w), "b l (h d) -> b h l d", h=H)
|
||
o, _ = chunk_rwkv6(
|
||
r, k, v, w, u=u, initial_state=None, output_final_state=False
|
||
)
|
||
x = rearrange(o, "b h l d -> b l (h d)")
|
||
return x
|
||
|
||
else:
|
||
from torch.utils.cpp_extension import load
|
||
|
||
HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"])
|
||
|
||
if "x060" in os.environ["RWKV_MY_TESTING"]:
|
||
if os.environ["RWKV_TRAIN_TYPE"] == "infctx":
|
||
wkv6state_cuda = load(
|
||
name="wkv6infctx",
|
||
sources=[
|
||
"finetune/lora/v6/cuda/wkv6infctx_op.cpp",
|
||
f"finetune/lora/v6/cuda/wkv6infctx_cuda.cu",
|
||
],
|
||
verbose=True,
|
||
extra_cuda_cflags=[
|
||
"-res-usage",
|
||
"--use_fast_math",
|
||
"-O3",
|
||
"-Xptxas -O3",
|
||
"--extra-device-vectorization",
|
||
f"-D_N_={HEAD_SIZE}",
|
||
f"-D_T_={int(os.environ['RWKV_CTXLEN'])}",
|
||
],
|
||
)
|
||
|
||
class WKV_6STATE(torch.autograd.Function):
|
||
@staticmethod
|
||
def forward(ctx, B, T, C, H, r, k, v, w, u, s):
|
||
with torch.no_grad():
|
||
assert r.dtype == torch.bfloat16
|
||
assert k.dtype == torch.bfloat16
|
||
assert v.dtype == torch.bfloat16
|
||
assert w.dtype == torch.bfloat16
|
||
assert u.dtype == torch.bfloat16
|
||
assert s.dtype == torch.bfloat16
|
||
assert HEAD_SIZE == C // H
|
||
ctx.B = B
|
||
ctx.T = T
|
||
ctx.C = C
|
||
ctx.H = H
|
||
assert r.is_contiguous()
|
||
assert k.is_contiguous()
|
||
assert v.is_contiguous()
|
||
assert w.is_contiguous()
|
||
assert u.is_contiguous()
|
||
assert s.is_contiguous()
|
||
ctx.save_for_backward(r, k, v, w, u, s)
|
||
y = torch.empty(
|
||
(B, T, C),
|
||
device=r.device,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
wkv6state_cuda.forward(B, T, C, H, r, k, v, w, u, s, y)
|
||
return y
|
||
|
||
@staticmethod
|
||
def backward(ctx, gy):
|
||
with torch.no_grad():
|
||
assert gy.dtype == torch.bfloat16
|
||
B = ctx.B
|
||
T = ctx.T
|
||
C = ctx.C
|
||
H = ctx.H
|
||
assert gy.is_contiguous()
|
||
r, k, v, w, u, s = ctx.saved_tensors
|
||
gr = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gk = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gv = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gw = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gu = torch.empty(
|
||
(B, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gs = torch.empty(
|
||
(B, H, C // H, C // H),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
wkv6state_cuda.backward(
|
||
B, T, C, H, r, k, v, w, u, s, gy, gr, gk, gv, gw, gu, gs
|
||
)
|
||
gu = torch.sum(gu, 0).view(H, C // H)
|
||
gs = torch.sum(gs, 0).view(H, C // H, C // H)
|
||
return (None, None, None, None, gr, gk, gv, gw, gu, gs)
|
||
|
||
def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s):
|
||
x = WKV_6STATE.apply(B, T, C, H, r, k, v, w, u, s)
|
||
return x, s
|
||
|
||
elif os.environ["RWKV_TRAIN_TYPE"] == "states":
|
||
wkv6state_cuda = load(
|
||
name="wkv6state",
|
||
sources=[
|
||
"finetune/lora/v6/cuda/wkv6state_op.cpp",
|
||
f"finetune/lora/v6/cuda/wkv6state_cuda.cu",
|
||
],
|
||
verbose=True,
|
||
extra_cuda_cflags=[
|
||
"-res-usage",
|
||
"--use_fast_math",
|
||
"-O3",
|
||
"-Xptxas -O3",
|
||
"--extra-device-vectorization",
|
||
f"-D_N_={HEAD_SIZE}",
|
||
f"-D_T_={int(os.environ['RWKV_CTXLEN'])}",
|
||
],
|
||
)
|
||
|
||
class WKV_6STATE(torch.autograd.Function):
|
||
@staticmethod
|
||
def forward(ctx, B, T, C, H, r, k, v, w, u, s):
|
||
with torch.no_grad():
|
||
assert r.dtype == torch.bfloat16
|
||
assert k.dtype == torch.bfloat16
|
||
assert v.dtype == torch.bfloat16
|
||
assert w.dtype == torch.bfloat16
|
||
assert u.dtype == torch.bfloat16
|
||
assert s.dtype == torch.bfloat16
|
||
assert HEAD_SIZE == C // H
|
||
ctx.B = B
|
||
ctx.T = T
|
||
ctx.C = C
|
||
ctx.H = H
|
||
assert r.is_contiguous()
|
||
assert k.is_contiguous()
|
||
assert v.is_contiguous()
|
||
assert w.is_contiguous()
|
||
assert u.is_contiguous()
|
||
assert s.is_contiguous()
|
||
ctx.save_for_backward(r, k, v, w, u, s)
|
||
y = torch.empty(
|
||
(B, T, C),
|
||
device=r.device,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
wkv6state_cuda.forward(B, T, C, H, r, k, v, w, u, s, y)
|
||
return y
|
||
|
||
@staticmethod
|
||
def backward(ctx, gy):
|
||
with torch.no_grad():
|
||
assert gy.dtype == torch.bfloat16
|
||
B = ctx.B
|
||
T = ctx.T
|
||
C = ctx.C
|
||
H = ctx.H
|
||
assert gy.is_contiguous()
|
||
r, k, v, w, u, s = ctx.saved_tensors
|
||
gr = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gk = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gv = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gw = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gu = torch.empty(
|
||
(B, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gs = torch.empty(
|
||
(B, H, C // H, C // H),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
wkv6state_cuda.backward(
|
||
B, T, C, H, r, k, v, w, u, s, gy, gr, gk, gv, gw, gu, gs
|
||
)
|
||
gu = torch.sum(gu, 0).view(H, C // H)
|
||
gs = torch.sum(gs, 0).view(H, C // H, C // H)
|
||
return (None, None, None, None, gr, gk, gv, gw, gu, gs)
|
||
|
||
def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s):
|
||
return WKV_6STATE.apply(B, T, C, H, r, k, v, w, u, s)
|
||
|
||
else:
|
||
wkv6_cuda = load(
|
||
name="wkv6",
|
||
sources=[
|
||
"finetune/lora/v6/cuda/wkv6_op.cpp",
|
||
f"finetune/lora/v6/cuda/wkv6_cuda.cu",
|
||
],
|
||
verbose=True,
|
||
extra_cuda_cflags=[
|
||
"-res-usage",
|
||
"--use_fast_math",
|
||
"-O3",
|
||
"-Xptxas -O3",
|
||
"--extra-device-vectorization",
|
||
f"-D_N_={HEAD_SIZE}",
|
||
f"-D_T_={int(os.environ['RWKV_CTXLEN'])}",
|
||
],
|
||
)
|
||
|
||
class WKV_6(torch.autograd.Function):
|
||
@staticmethod
|
||
def forward(ctx, B, T, C, H, r, k, v, w, u):
|
||
with torch.no_grad():
|
||
assert r.dtype == torch.bfloat16
|
||
assert k.dtype == torch.bfloat16
|
||
assert v.dtype == torch.bfloat16
|
||
assert w.dtype == torch.bfloat16
|
||
assert u.dtype == torch.bfloat16
|
||
assert HEAD_SIZE == C // H
|
||
ctx.B = B
|
||
ctx.T = T
|
||
ctx.C = C
|
||
ctx.H = H
|
||
assert r.is_contiguous()
|
||
assert k.is_contiguous()
|
||
assert v.is_contiguous()
|
||
assert w.is_contiguous()
|
||
assert u.is_contiguous()
|
||
ew = (-torch.exp(w.float())).contiguous()
|
||
ctx.save_for_backward(r, k, v, ew, u)
|
||
y = torch.empty(
|
||
(B, T, C),
|
||
device=r.device,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
wkv6_cuda.forward(B, T, C, H, r, k, v, ew, u, y)
|
||
return y
|
||
|
||
@staticmethod
|
||
def backward(ctx, gy):
|
||
with torch.no_grad():
|
||
assert gy.dtype == torch.bfloat16
|
||
B = ctx.B
|
||
T = ctx.T
|
||
C = ctx.C
|
||
H = ctx.H
|
||
assert gy.is_contiguous()
|
||
r, k, v, ew, u = ctx.saved_tensors
|
||
gr = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gk = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gv = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gw = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
gu = torch.empty(
|
||
(B, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-100, 100)
|
||
wkv6_cuda.backward(
|
||
B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu
|
||
)
|
||
gu = torch.sum(gu, 0).view(H, C // H)
|
||
return (None, None, None, None, gr, gk, gv, gw, gu)
|
||
|
||
def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
|
||
return WKV_6.apply(B, T, C, H, r, k, v, w, u)
|
||
|
||
else:
|
||
wkv5_cuda = load(
|
||
name="wkv5",
|
||
sources=[
|
||
"finetune/lora/v6/cuda/wkv5_op.cpp",
|
||
f"finetune/lora/v6/cuda/wkv5_cuda.cu",
|
||
],
|
||
verbose=True,
|
||
extra_cuda_cflags=[
|
||
"-res-usage",
|
||
"--use_fast_math",
|
||
"-O3",
|
||
"-Xptxas -O3",
|
||
"--extra-device-vectorization",
|
||
f"-D_N_={HEAD_SIZE}",
|
||
],
|
||
)
|
||
|
||
class WKV_5(torch.autograd.Function):
|
||
@staticmethod
|
||
def forward(ctx, B, T, C, H, r, k, v, w, u):
|
||
with torch.no_grad():
|
||
assert r.dtype == torch.bfloat16
|
||
assert k.dtype == torch.bfloat16
|
||
assert v.dtype == torch.bfloat16
|
||
assert w.dtype == torch.bfloat16
|
||
assert u.dtype == torch.bfloat16
|
||
assert HEAD_SIZE == C // H
|
||
ctx.B = B
|
||
ctx.T = T
|
||
ctx.C = C
|
||
ctx.H = H
|
||
assert r.is_contiguous()
|
||
assert k.is_contiguous()
|
||
assert v.is_contiguous()
|
||
assert w.is_contiguous()
|
||
assert u.is_contiguous()
|
||
ew = (-torch.exp(w.float())).contiguous()
|
||
eew = (torch.exp(ew)).contiguous()
|
||
ctx.save_for_backward(r, k, v, eew, ew, u)
|
||
y = torch.empty(
|
||
(B, T, C),
|
||
device=r.device,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-1, 1)
|
||
wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
|
||
return y
|
||
|
||
@staticmethod
|
||
def backward(ctx, gy):
|
||
with torch.no_grad():
|
||
assert gy.dtype == torch.bfloat16
|
||
B = ctx.B
|
||
T = ctx.T
|
||
C = ctx.C
|
||
H = ctx.H
|
||
assert gy.is_contiguous()
|
||
r, k, v, eew, ew, u = ctx.saved_tensors
|
||
gr = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-1, 1)
|
||
gk = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-1, 1)
|
||
gv = torch.empty(
|
||
(B, T, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-1, 1)
|
||
gw = torch.empty(
|
||
(B, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-1, 1)
|
||
gu = torch.empty(
|
||
(B, C),
|
||
device=gy.device,
|
||
requires_grad=False,
|
||
dtype=torch.bfloat16,
|
||
memory_format=torch.contiguous_format,
|
||
) # .uniform_(-1, 1)
|
||
wkv5_cuda.backward(
|
||
B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu
|
||
)
|
||
gw = torch.sum(gw, 0).view(H, C // H)
|
||
gu = torch.sum(gu, 0).view(H, C // H)
|
||
return (None, None, None, None, gr, gk, gv, gw, gu)
|
||
|
||
def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
|
||
return WKV_5.apply(B, T, C, H, r, k, v, w, u)
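
# Shape contract of the kernel wrappers above, as implied by the rearrange patterns and
# the buffers allocated in the autograd Functions (descriptive note, not authoritative):
#   r, k, v, w : (B, T, C) with C = H * HEAD_SIZE, bfloat16 in the CUDA paths
#   u          : (H, C // H)  -- the per-head "time_faaaa" bonus
#   s          : the initial per-head wkv state (head_size x head_size per head),
#                only for the *_STATE variants
#   returns    : (B, T, C); the infctx variants also return a wkv state for the
#                next chunk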
|
||
|
||
|
||
########################################################################################################
|
||
|
||
|
||
class RWKV_TimeMix_RWKV5(MyModule):
|
||
def __init__(self, args, layer_id):
|
||
super().__init__()
|
||
self.args = args
|
||
self.layer_id = layer_id
|
||
|
||
self.head_size = args.head_size_a
|
||
assert HEAD_SIZE == self.head_size # change HEAD_SIZE to match args.head_size_a
|
||
self.n_head = args.dim_att // self.head_size
|
||
assert args.dim_att % self.n_head == 0
|
||
self.head_size_divisor = args.head_size_divisor
|
||
|
||
with torch.no_grad():
|
||
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
|
||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||
ddd = torch.ones(1, 1, args.n_embd)
|
||
for i in range(args.n_embd):
|
||
ddd[0, 0, i] = i / args.n_embd
|
||
|
||
# fancy time_mix
|
||
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_mix_v = nn.Parameter(
|
||
torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
|
||
)
|
||
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
||
self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
||
|
||
# fancy time_decay
|
||
decay_speed = torch.ones(args.dim_att)
|
||
for n in range(args.dim_att):
|
||
decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
|
||
0.7 + 1.3 * ratio_0_to_1
|
||
)
|
||
self.time_decay = nn.Parameter(
|
||
decay_speed.reshape(self.n_head, self.head_size)
|
||
)
|
||
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
|
||
|
||
tmp = torch.zeros(args.dim_att)
|
||
for n in range(args.dim_att):
|
||
zigzag = ((n + 1) % 3 - 1) * 0.1
|
||
tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
|
||
|
||
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
|
||
|
||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||
self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||
self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||
self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||
|
||
self.output = make_linear_att(args.dim_att, args.n_embd, bias=False)
|
||
self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||
self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)
|
||
|
||
@MyFunction
|
||
def jit_func(self, x):
|
||
B, T, C = x.size()
|
||
|
||
xx = self.time_shift(
|
||
x
|
||
) # Mix x with the previous timestep to produce xk, xv, xr
|
||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||
xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)
|
||
|
||
r = self.receptance(xr)
|
||
k = self.key(xk)
|
||
v = self.value(xv)
|
||
g = F.silu(self.gate(xg))
|
||
|
||
return r, k, v, g
|
||
|
||
@MyFunction
|
||
def jit_func_2(self, x, g):
|
||
B, T, C = x.size()
|
||
x = x.view(B * T, C)
|
||
|
||
x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
|
||
x = self.output(x * g)
|
||
return x
|
||
|
||
def forward(self, x):
|
||
B, T, C = x.size()
|
||
H = self.n_head
|
||
|
||
r, k, v, g = self.jit_func(x)
|
||
|
||
x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)
|
||
|
||
return self.jit_func_2(x, g)
|
||
|
||
|
||
class RWKV_Tmix_x060(MyModule):
|
||
def __init__(self, args, layer_id):
|
||
super().__init__()
|
||
self.args = args
|
||
self.layer_id = layer_id
|
||
|
||
self.head_size = args.head_size_a
|
||
self.n_head = args.dim_att // self.head_size
|
||
assert args.dim_att % self.n_head == 0
|
||
|
||
with torch.no_grad():
|
||
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
|
||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||
ddd = torch.ones(1, 1, args.n_embd)
|
||
for i in range(args.n_embd):
|
||
ddd[0, 0, i] = i / args.n_embd
|
||
|
||
# fancy time_mix
|
||
self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_maa_v = nn.Parameter(
|
||
1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
||
)
|
||
self.time_maa_r = nn.Parameter(
|
||
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
|
||
)
|
||
self.time_maa_g = nn.Parameter(
|
||
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
|
||
)
|
||
|
||
TIME_MIX_EXTRA_DIM = 32 # generate TIME_MIX for w,k,v,r,g
|
||
if args.n_embd == 4096:
|
||
TIME_MIX_EXTRA_DIM = TIME_MIX_EXTRA_DIM * 2
|
||
self.time_maa_w1 = nn.Parameter(
|
||
torch.zeros(args.n_embd, TIME_MIX_EXTRA_DIM * 5).uniform_(-1e-4, 1e-4)
|
||
)
|
||
self.time_maa_w2 = nn.Parameter(
|
||
torch.zeros(5, TIME_MIX_EXTRA_DIM, args.n_embd).uniform_(-1e-4, 1e-4)
|
||
)
|
||
|
||
# fancy time_decay
|
||
decay_speed = torch.ones(args.dim_att)
|
||
for n in range(args.dim_att):
|
||
decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
|
||
0.7 + 1.3 * ratio_0_to_1
|
||
)
|
||
self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
|
||
|
||
TIME_DECAY_EXTRA_DIM = 64
|
||
if args.n_embd == 4096:
|
||
TIME_DECAY_EXTRA_DIM = TIME_DECAY_EXTRA_DIM * 2
|
||
self.time_decay_w1 = nn.Parameter(
|
||
torch.zeros(args.n_embd, TIME_DECAY_EXTRA_DIM).uniform_(-1e-4, 1e-4)
|
||
)
|
||
self.time_decay_w2 = nn.Parameter(
|
||
torch.zeros(TIME_DECAY_EXTRA_DIM, args.dim_att).uniform_(-1e-4, 1e-4)
|
||
)
|
||
|
||
tmp = torch.zeros(args.dim_att)
|
||
for n in range(args.dim_att):
|
||
zigzag = ((n + 1) % 3 - 1) * 0.1
|
||
tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
|
||
|
||
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
|
||
|
||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||
self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||
self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||
self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||
self.output = make_linear_att(args.dim_att, args.n_embd, bias=False)
|
||
self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||
self.ln_x = nn.GroupNorm(
|
||
self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
|
||
)
|
||
|
||
@MyFunction
|
||
def jit_func(self, x):
|
||
B, T, C = x.size()
|
||
|
||
xx = self.time_shift(x) - x
|
||
|
||
xxx = x + xx * self.time_maa_x
|
||
xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
|
||
xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
|
||
mw, mk, mv, mr, mg = xxx.unbind(dim=0)
|
||
|
||
xw = x + xx * (self.time_maa_w + mw)
|
||
xk = x + xx * (self.time_maa_k + mk)
|
||
xv = x + xx * (self.time_maa_v + mv)
|
||
xr = x + xx * (self.time_maa_r + mr)
|
||
xg = x + xx * (self.time_maa_g + mg)
|
||
|
||
r = self.receptance(xr)
|
||
k = self.key(xk)
|
||
v = self.value(xv)
|
||
g = F.silu(self.gate(xg))
|
||
|
||
ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
|
||
w = self.time_decay + ww
|
||
|
||
return r, k, v, g, w
|
||
|
||
@MyFunction
|
||
def jit_func_2(self, x, g):
|
||
B, T, C = x.size()
|
||
x = x.view(B * T, C)
|
||
|
||
x = self.ln_x(x).view(B, T, C)
|
||
x = self.output(x * g)
|
||
return x
|
||
|
||
def forward(self, x):
|
||
B, T, C = x.size()
|
||
H = self.n_head
|
||
|
||
r, k, v, g, w = self.jit_func(x)
|
||
x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)
|
||
|
||
return self.jit_func_2(x, g)
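
# Reading aid for RWKV_Tmix_x060.jit_func: xx is the difference between the previous
# timestep's input and the current one (time_shift pads a zero row at t=0). A small
# two-layer projection (time_maa_w1/w2, 32 or 64 hidden units) produces five
# data-dependent corrections mw, mk, mv, mr, mg, which are added to the static
# time_maa_* mixing coefficients before the r/k/v/g projections. The decay w is built
# the same way from time_decay_w1/w2 on top of the per-channel time_decay bias and is
# passed to the wkv kernel un-exponentiated (the kernels apply -exp(w) themselves).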
|
||
|
||
|
||
########################################################################################################
|
||
|
||
|
||
class RWKV_Tmix_x060_state(MyModule):
|
||
def __init__(self, args, layer_id):
|
||
super().__init__()
|
||
self.args = args
|
||
self.layer_id = layer_id
|
||
|
||
self.head_size = args.head_size_a
|
||
self.n_head = args.dim_att // self.head_size
|
||
assert args.dim_att % self.n_head == 0
|
||
|
||
with torch.no_grad():
|
||
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
|
||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||
ddd = torch.ones(1, 1, args.n_embd)
|
||
for i in range(args.n_embd):
|
||
ddd[0, 0, i] = i / args.n_embd
|
||
|
||
# fancy time_mix
|
||
self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_maa_v = nn.Parameter(
|
||
1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
||
)
|
||
self.time_maa_r = nn.Parameter(
|
||
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
|
||
)
|
||
self.time_maa_g = nn.Parameter(
|
||
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
|
||
)
|
||
|
||
D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g
|
||
if args.n_embd == 4096:
|
||
D_MIX_LORA = D_MIX_LORA * 2
|
||
self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
|
||
self.time_maa_w2 = nn.Parameter(
|
||
torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01)
|
||
)
|
||
|
||
# fancy time_decay
|
||
decay_speed = torch.ones(args.dim_att)
|
||
for n in range(args.dim_att):
|
||
decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
|
||
0.7 + 1.3 * ratio_0_to_1
|
||
)
|
||
self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
|
||
|
||
D_DECAY_LORA = 64
|
||
if args.n_embd == 4096:
|
||
D_DECAY_LORA = D_DECAY_LORA * 2
|
||
self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
|
||
self.time_decay_w2 = nn.Parameter(
|
||
torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01)
|
||
)
|
||
|
||
tmp = torch.zeros(args.dim_att)
|
||
for n in range(args.dim_att):
|
||
zigzag = ((n + 1) % 3 - 1) * 0.1
|
||
tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
|
||
|
||
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
|
||
self.time_state = nn.Parameter(
|
||
torch.zeros(self.n_head, self.head_size, self.head_size)
|
||
)
|
||
|
||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||
self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
|
||
self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
|
||
|
||
self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
|
||
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
|
||
self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
|
||
self.ln_x = nn.GroupNorm(
|
||
self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
|
||
)
|
||
|
||
@MyFunction
|
||
def jit_func(self, x):
|
||
B, T, C = x.size()
|
||
|
||
xx = self.time_shift(x) - x
|
||
|
||
xxx = x + xx * self.time_maa_x
|
||
xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
|
||
xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
|
||
mw, mk, mv, mr, mg = xxx.unbind(dim=0)
|
||
|
||
xw = x + xx * (self.time_maa_w + mw)
|
||
xk = x + xx * (self.time_maa_k + mk)
|
||
xv = x + xx * (self.time_maa_v + mv)
|
||
xr = x + xx * (self.time_maa_r + mr)
|
||
xg = x + xx * (self.time_maa_g + mg)
|
||
|
||
r = self.receptance(xr)
|
||
k = self.key(xk)
|
||
v = self.value(xv)
|
||
g = F.silu(self.gate(xg))
|
||
|
||
ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
|
||
w = self.time_decay + ww
|
||
|
||
return r, k, v, g, w
|
||
|
||
@MyFunction
|
||
def jit_func_2(self, x, g):
|
||
B, T, C = x.size()
|
||
x = x.view(B * T, C)
|
||
|
||
x = self.ln_x(x).view(B, T, C)
|
||
x = self.output(x * g)
|
||
return x
|
||
|
||
def forward(self, x):
|
||
B, T, C = x.size()
|
||
H = self.n_head
|
||
|
||
r, k, v, g, w = self.jit_func(x)
|
||
x = RUN_CUDA_RWKV6_STATE(
|
||
B, T, C, H, r, k, v, w, u=self.time_faaaa, s=self.time_state
|
||
)
|
||
|
||
return self.jit_func_2(x, g)
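
# RWKV_Tmix_x060_state differs from RWKV_Tmix_x060 mainly in that the initial wkv state
# (time_state, one head_size x head_size matrix per head) is itself a trainable
# parameter and is fed to RUN_CUDA_RWKV6_STATE as the initial state s; the projections
# here are plain nn.Linear rather than the LoRA-aware factories.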
|
||
|
||
|
||
########################################################################################################
|
||
|
||
|
||
class RWKV_ChannelMix(MyModule):
|
||
def __init__(self, args, layer_id):
|
||
super().__init__()
|
||
self.args = args
|
||
self.layer_id = layer_id
|
||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||
|
||
with torch.no_grad(): # fancy init of time_mix
|
||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||
ddd = torch.ones(1, 1, args.n_embd)
|
||
for i in range(args.n_embd):
|
||
ddd[0, 0, i] = i / args.n_embd
|
||
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||
|
||
self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
|
||
self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
|
||
self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)
|
||
|
||
@MyFunction
|
||
def forward(self, x):
|
||
xx = self.time_shift(x)
|
||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||
k = self.key(xk)
|
||
k = torch.relu(k) ** 2
|
||
kv = self.value(k)
|
||
return torch.sigmoid(self.receptance(xr)) * kv
|
||
|
||
|
||
class RWKV_CMix_x060(MyModule):
|
||
def __init__(self, args, layer_id):
|
||
super().__init__()
|
||
self.args = args
|
||
self.layer_id = layer_id
|
||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||
|
||
with torch.no_grad(): # fancy init of time_mix
|
||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||
ddd = torch.ones(1, 1, args.n_embd)
|
||
for i in range(args.n_embd):
|
||
ddd[0, 0, i] = i / args.n_embd
|
||
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
|
||
self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
|
||
self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
|
||
self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)
|
||
|
||
@MyFunction
|
||
def forward(self, x):
|
||
xx = self.time_shift(x) - x
|
||
xk = x + xx * self.time_maa_k
|
||
xr = x + xx * self.time_maa_r
|
||
|
||
k = self.key(xk)
|
||
k = torch.relu(k) ** 2
|
||
kv = self.value(k)
|
||
return torch.sigmoid(self.receptance(xr)) * kv
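
# The channel-mix FFN above computes, per token,
#     out = sigmoid(R(xr)) * V(relu(K(xk))^2)
# i.e. a squared-ReLU MLP (n_embd -> dim_ffn -> n_embd) gated by a sigmoid
# "receptance"; xk and xr interpolate between the current and previous timestep using
# the learned time_maa_k / time_maa_r coefficients.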
|
||
|
||
|
||
########################################################################################################
|
||
|
||
|
||
class MishGLU(MyModule):
|
||
def __init__(self, args, layer_id):
|
||
super().__init__()
|
||
self.args = args
|
||
self.layer_id = layer_id
|
||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||
|
||
with torch.no_grad():
|
||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
|
||
|
||
x = torch.ones(1, 1, args.n_embd)
|
||
for i in range(args.n_embd):
|
||
x[0, 0, i] = i / args.n_embd
|
||
|
||
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
||
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
||
self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
||
self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
||
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
|
||
|
||
@MyFunction
|
||
def forward(self, x):
|
||
xx = self.time_shift(x)
|
||
xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||
xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||
a = self.aa(xa)
|
||
b = self.bb(xb)
|
||
return self.value(a * F.mish(b))
|
||
|
||
|
||
########################################################################################################
|
||
|
||
|
||
class RWKV_Tmix_x060_infctx(MyModule):
|
||
def __init__(self, args, layer_id):
|
||
super().__init__()
|
||
self.args = args
|
||
self.layer_id = layer_id
|
||
|
||
self.head_size = args.head_size_a
|
||
self.n_head = args.dim_att // self.head_size
|
||
assert args.dim_att % self.n_head == 0
|
||
|
||
with torch.no_grad():
|
||
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
|
||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||
ddd = torch.ones(1, 1, args.n_embd)
|
||
for i in range(args.n_embd):
|
||
ddd[0, 0, i] = i / args.n_embd
|
||
|
||
# fancy time_mix
|
||
self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_maa_v = nn.Parameter(
|
||
1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
||
)
|
||
self.time_maa_r = nn.Parameter(
|
||
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
|
||
)
|
||
self.time_maa_g = nn.Parameter(
|
||
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
|
||
)
|
||
|
||
D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g
|
||
if args.n_embd == 4096:
|
||
D_MIX_LORA = D_MIX_LORA * 2
|
||
self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
|
||
self.time_maa_w2 = nn.Parameter(
|
||
torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01)
|
||
)
|
||
|
||
# fancy time_decay
|
||
decay_speed = torch.ones(args.dim_att)
|
||
for n in range(args.dim_att):
|
||
decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
|
||
0.7 + 1.3 * ratio_0_to_1
|
||
)
|
||
self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
|
||
|
||
D_DECAY_LORA = 64
|
||
if args.n_embd == 4096:
|
||
D_DECAY_LORA = D_DECAY_LORA * 2
|
||
self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
|
||
self.time_decay_w2 = nn.Parameter(
|
||
torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01)
|
||
)
|
||
|
||
tmp = torch.zeros(args.dim_att)
|
||
for n in range(args.dim_att):
|
||
zigzag = ((n + 1) % 3 - 1) * 0.1
|
||
tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
|
||
|
||
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
|
||
# self.time_state = nn.Parameter(torch.zeros(self.n_head, self.head_size, self.head_size))
|
||
|
||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||
self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
|
||
self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
|
||
|
||
self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
|
||
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
|
||
self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
|
||
self.ln_x = nn.GroupNorm(
|
||
self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
|
||
)
|
||
|
||
@MyFunction
|
||
def jit_func(self, x, shift_state):
|
||
B, T, C = x.size()
|
||
xx = torch.concat((shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x
|
||
|
||
xxx = x + xx * self.time_maa_x
|
||
xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
|
||
xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
|
||
mw, mk, mv, mr, mg = xxx.unbind(dim=0)
|
||
|
||
xw = x + xx * (self.time_maa_w + mw)
|
||
xk = x + xx * (self.time_maa_k + mk)
|
||
xv = x + xx * (self.time_maa_v + mv)
|
||
xr = x + xx * (self.time_maa_r + mr)
|
||
xg = x + xx * (self.time_maa_g + mg)
|
||
|
||
r = self.receptance(xr)
|
||
k = self.key(xk)
|
||
v = self.value(xv)
|
||
g = F.silu(self.gate(xg))
|
||
|
||
ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
|
||
w = self.time_decay + ww
|
||
|
||
return r, k, v, g, w, x[:, -1]
|
||
|
||
@MyFunction
|
||
def jit_func_2(self, x, g, timemixstate: TimeMixState):
|
||
B, T, C = x.size()
|
||
x = x.view(B * T, C)
|
||
|
||
x = self.ln_x(x).view(B, T, C)
|
||
x = self.output(x * g)
|
||
return x, timemixstate
|
||
|
||
def forward(self, x, last_state: TimeMixState):
|
||
B, T, C = x.size()
|
||
H = self.n_head
|
||
shift_state = last_state.shift_state
|
||
r, k, v, g, w, lx = self.jit_func(x, shift_state)
|
||
######
|
||
wkv_state = last_state.wkv_state.clone().contiguous()
|
||
x, wkv_state = RUN_CUDA_RWKV6_STATE(
|
||
B, T, C, H, r, k, v, w, u=self.time_faaaa, s=wkv_state
|
||
)
|
||
# wkv_state = last_state.wkv_state
|
||
return self.jit_func_2(x, g, TimeMixState(lx, wkv_state))
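
# In the infctx variant the per-layer recurrent state is threaded through forward()
# explicitly: shift_state carries the last token of the previous chunk (so the token
# shift stays continuous across chunk boundaries) and wkv_state carries the per-head
# wkv matrix, which is cloned, passed to RUN_CUDA_RWKV6_STATE as the initial state, and
# returned inside a TimeMixState for the next chunk.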
|
||
|
||
|
||
class RWKV_CMix_x060_infctx(MyModule):
|
||
def __init__(self, args, layer_id):
|
||
super().__init__()
|
||
self.args = args
|
||
self.layer_id = layer_id
|
||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||
|
||
with torch.no_grad(): # fancy init of time_mix
|
||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||
ddd = torch.ones(1, 1, args.n_embd)
|
||
for i in range(args.n_embd):
|
||
ddd[0, 0, i] = i / args.n_embd
|
||
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
||
|
||
self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
|
||
self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
|
||
self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)
|
||
|
||
@MyFunction
|
||
def forward(self, x, last_state: ChannelMixState):
|
||
xx = torch.concat((last_state.shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x
|
||
xk = x + xx * self.time_maa_k
|
||
xr = x + xx * self.time_maa_r
|
||
|
||
k = self.key(xk)
|
||
k = torch.relu(k) ** 2
|
||
kv = self.value(k)
|
||
return torch.sigmoid(self.receptance(xr)) * kv, ChannelMixState(x[:, -1])
|
||
|
||
|
||
########################################################################################################
|
||
# The RWKV Model with our blocks
|
||
########################################################################################################
|
||
|
||
|
||
class Block(nn.Module):
|
||
def __init__(self, args, layer_id):
|
||
super().__init__()
|
||
self.args = args
|
||
self.layer_id = layer_id
|
||
|
||
self.ln1 = nn.LayerNorm(args.n_embd)
|
||
self.ln2 = nn.LayerNorm(args.n_embd)
|
||
|
||
if self.layer_id == 0:
|
||
self.ln0 = nn.LayerNorm(args.n_embd)
|
||
if args.my_pos_emb > 0:
|
||
self.pos_emb_x = nn.Parameter(
|
||
torch.zeros((1, args.my_pos_emb, args.n_embd))
|
||
)
|
||
self.pos_emb_y = nn.Parameter(
|
||
torch.zeros((args.my_pos_emb, 1, args.n_embd))
|
||
)
|
||
|
||
if self.layer_id == 0 and self.args.pre_ffn > 0:
|
||
self.ffnPre = RWKV_ChannelMix(args, 0)
|
||
else:
|
||
if "x060" in os.environ["RWKV_MY_TESTING"]:
|
||
if os.environ["RWKV_TRAIN_TYPE"] == "states":
|
||
self.att = RWKV_Tmix_x060_state(args, layer_id)
|
||
elif os.environ["RWKV_TRAIN_TYPE"] == "infctx":
|
||
self.att = RWKV_Tmix_x060_infctx(args, layer_id)
|
||
else:
|
||
self.att = RWKV_Tmix_x060(args, layer_id)
|
||
else:
|
||
self.att = RWKV_TimeMix_RWKV5(args, layer_id)
|
||
|
||
if "g" in os.environ["RWKV_MY_TESTING"]:
|
||
self.ffn = MishGLU(args, layer_id)
|
||
else:
|
||
if "x060" in os.environ["RWKV_MY_TESTING"]:
|
||
if os.environ["RWKV_TRAIN_TYPE"] == "infctx":
|
||
self.ffn = RWKV_CMix_x060_infctx(args, layer_id)
|
||
else:
|
||
self.ffn = RWKV_CMix_x060(args, layer_id)
|
||
else:
|
||
self.ffn = RWKV_ChannelMix(args, layer_id)
|
||
|
||
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
||
self.tiny_ln = nn.LayerNorm(args.n_embd)
|
||
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
|
||
self.register_buffer(
|
||
"tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||
)
|
||
|
||
if args.dropout > 0:
|
||
self.drop0 = nn.Dropout(p=args.dropout)
|
||
self.drop1 = nn.Dropout(p=args.dropout)
|
||
|
||
if os.environ["RWKV_TRAIN_TYPE"] == "infctx":
|
||
|
||
def forward(self, x, last_state: BlockState, x_emb=None):
|
||
args = self.args
|
||
B, T, C = x.size()
|
||
if self.layer_id == 0:
|
||
x = self.ln0(x)
|
||
if args.my_pos_emb > 0:
|
||
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[
|
||
:-1, :
|
||
]
|
||
x = x + pos_emb
|
||
|
||
if self.args.dropout == 0:
|
||
if self.layer_id == 0 and args.pre_ffn > 0:
|
||
x = x + self.ffnPre(self.ln1(x))
|
||
else:
|
||
att_out, att_state = self.att(
|
||
self.ln1(x), last_state.time_mix_state
|
||
)
|
||
x = x + att_out
|
||
ffn_out, fnn_state = self.ffn(self.ln2(x), last_state.channel_mix_state)
|
||
x = x + ffn_out
|
||
else:
|
||
if self.layer_id == 0 and args.pre_ffn > 0:
|
||
x = self.drop0(x + self.ffnPre(self.ln1(x)))
|
||
else:
|
||
x = self.drop0(x + self.att(self.ln1(x)))
|
||
x = self.drop1(x + self.ffn(self.ln2(x)))
|
||
|
||
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
||
xx = self.tiny_ln(x)
|
||
q = self.tiny_q(xx)[:, :T, :]
|
||
k = self.tiny_k(xx)[:, :T, :]
|
||
c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
|
||
c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
|
||
x = x + c @ self.tiny_v(x_emb)
|
||
return x, BlockState(att_state, fnn_state)
|
||
|
||
else:
|
||
|
||
def forward(self, x, x_emb=None):
|
||
args = self.args
|
||
B, T, C = x.size()
|
||
if self.layer_id == 0:
|
||
x = self.ln0(x)
|
||
if args.my_pos_emb > 0:
|
||
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[
|
||
:-1, :
|
||
]
|
||
x = x + pos_emb
|
||
|
||
if self.args.dropout == 0:
|
||
if self.layer_id == 0 and args.pre_ffn > 0:
|
||
x = x + self.ffnPre(self.ln1(x))
|
||
else:
|
||
x = x + self.att(self.ln1(x))
|
||
x = x + self.ffn(self.ln2(x))
|
||
else:
|
||
if self.layer_id == 0 and args.pre_ffn > 0:
|
||
x = self.drop0(x + self.ffnPre(self.ln1(x)))
|
||
else:
|
||
x = self.drop0(x + self.att(self.ln1(x)))
|
||
x = self.drop1(x + self.ffn(self.ln2(x)))
|
||
|
||
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
||
xx = self.tiny_ln(x)
|
||
q = self.tiny_q(xx)[:, :T, :]
|
||
k = self.tiny_k(xx)[:, :T, :]
|
||
c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
|
||
c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
|
||
x = x + c @ self.tiny_v(x_emb)
|
||
return x
|
||
|
||
|
||
if os.environ["RWKV_TRAIN_TYPE"] == "infctx":
|
||
|
||
class L2Wrap(torch.autograd.Function):
|
||
@staticmethod
|
||
def forward(ctx, loss, y, token_amount):
|
||
ctx.save_for_backward(y)
|
||
ctx.token_amount = token_amount
|
||
return loss
|
||
|
||
        @staticmethod
        def backward(
            ctx, grad_output
        ):  # Does this break consistency between batch size and gradient accumulation?
            # It seems so: with gradient accumulation the factor grows, but only the
            # loss is scaled, while this regularization term is not scaled accordingly.
            y = ctx.saved_tensors[0]
            # to encourage the logits to be close to 0
            if ctx.token_amount == 0:
                return (grad_output, None, None)
            factor = 1e-4 / ctx.token_amount  # like cross-entropy, average over tokens
            maxx, ids = torch.max(y, -1, keepdim=True)
            gy = torch.zeros_like(y)
            if os.environ.get("WN_FIX_L2WRAP"):  # for batch-size equivalence
                # maxx[maxx<3.]=0.  # avoid pulling down logits that are already small;
                # only pull down those above the threshold
                gy.scatter_(-1, ids, maxx * factor * grad_output)
            else:
                gy.scatter_(-1, ids, maxx * factor)
            return (grad_output, gy, None)
|
||
|
||
else:
|
||
|
||
class L2Wrap(torch.autograd.Function):
|
||
@staticmethod
|
||
def forward(ctx, loss, y):
|
||
ctx.save_for_backward(y)
|
||
return loss
|
||
|
||
@staticmethod
|
||
def backward(ctx, grad_output):
|
||
y = ctx.saved_tensors[0]
|
||
# to encourage the logits to be close to 0
|
||
factor = 1e-4 / (y.shape[0] * y.shape[1])
|
||
maxx, ids = torch.max(y, -1, keepdim=True)
|
||
gy = torch.zeros_like(y)
|
||
gy.scatter_(-1, ids, maxx * factor)
|
||
return (grad_output, gy)
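
# L2Wrap adds a tiny auxiliary gradient on top of the loss gradient: at each position
# only the largest logit receives an extra gradient of (1e-4 / (B*T)) * max_logit,
# which behaves like a small L2 penalty on the per-token max logit and keeps the logits
# from drifting upward (see the "encourage the logits to be close to 0" note above).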
|
||
|
||
|
||
class RWKV(pl.LightningModule):
|
||
def __init__(self, args):
|
||
super().__init__()
|
||
self.args = args
|
||
if not hasattr(args, "dim_att"):
|
||
args.dim_att = args.n_embd
|
||
if not hasattr(args, "dim_ffn"):
|
||
args.dim_ffn = args.n_embd * 4
|
||
if not hasattr(args, "tiny_att_layer"):
|
||
args.tiny_att_layer = -1
|
||
if not hasattr(args, "tiny_att_dim"):
|
||
args.tiny_att_dim = -1
|
||
assert args.n_embd % 32 == 0
|
||
assert args.dim_att % 32 == 0
|
||
assert args.dim_ffn % 32 == 0
|
||
|
||
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
|
||
|
||
self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
|
||
|
||
self.ln_out = nn.LayerNorm(args.n_embd)
|
||
self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
|
||
|
||
if args.head_qk > 0:
|
||
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||
self.register_buffer(
|
||
"copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
|
||
)
|
||
if args.dropout > 0:
|
||
self.drop0 = nn.Dropout(p=args.dropout)
|
||
|
||
def configure_optimizers(self):
|
||
args = self.args
|
||
|
||
lr_decay = set()
|
||
lr_1x = set()
|
||
lr_2x = set()
|
||
lr_3x = set()
|
||
for n, p in self.named_parameters():
|
||
if not p.requires_grad:
|
||
continue
|
||
if (("_w1" in n) or ("_w2" in n)) and (args.layerwise_lr > 0):
|
||
lr_1x.add(n)
|
||
elif (("time_mix" in n) or ("time_maa" in n)) and (args.layerwise_lr > 0):
|
||
if args.my_pile_stage == 2:
|
||
lr_2x.add(n)
|
||
else:
|
||
lr_1x.add(n)
|
||
elif (("time_decay" in n) or ("time_daaaa" in n)) and (
|
||
args.layerwise_lr > 0
|
||
):
|
||
if args.my_pile_stage == 2:
|
||
lr_3x.add(n)
|
||
else:
|
||
lr_2x.add(n)
|
||
elif ("time_faaaa" in n) and (args.layerwise_lr > 0):
|
||
if args.my_pile_stage == 2:
|
||
lr_2x.add(n)
|
||
else:
|
||
lr_1x.add(n)
|
||
elif ("time_first" in n) and (args.layerwise_lr > 0):
|
||
lr_3x.add(n)
|
||
elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0):
|
||
lr_decay.add(n)
|
||
else:
|
||
lr_1x.add(n)
|
||
|
||
lr_decay = sorted(list(lr_decay))
|
||
lr_1x = sorted(list(lr_1x))
|
||
lr_2x = sorted(list(lr_2x))
|
||
lr_3x = sorted(list(lr_3x))
|
||
# print('decay', lr_decay)
|
||
# print('1x', lr_1x)
|
||
# print('2x', lr_2x)
|
||
# print('3x', lr_3x)
|
||
param_dict = {n: p for n, p in self.named_parameters()}
|
||
|
||
if args.layerwise_lr > 0:
|
||
if args.my_pile_stage == 2:
|
||
optim_groups = [
|
||
{
|
||
"params": [param_dict[n] for n in lr_1x],
|
||
"weight_decay": 0.0,
|
||
"my_lr_scale": 1.0,
|
||
},
|
||
{
|
||
"params": [param_dict[n] for n in lr_2x],
|
||
"weight_decay": 0.0,
|
||
"my_lr_scale": 5.0,
|
||
}, # test: 2e-3 / args.lr_init},
|
||
{
|
||
"params": [param_dict[n] for n in lr_3x],
|
||
"weight_decay": 0.0,
|
||
"my_lr_scale": 5.0,
|
||
}, # test: 3e-3 / args.lr_init},
|
||
]
|
||
else:
|
||
optim_groups = [
|
||
{
|
||
"params": [param_dict[n] for n in lr_1x],
|
||
"weight_decay": 0.0,
|
||
"my_lr_scale": 1.0,
|
||
},
|
||
{
|
||
"params": [param_dict[n] for n in lr_2x],
|
||
"weight_decay": 0.0,
|
||
"my_lr_scale": 2.0,
|
||
},
|
||
{
|
||
"params": [param_dict[n] for n in lr_3x],
|
||
"weight_decay": 0.0,
|
||
"my_lr_scale": 3.0,
|
||
},
|
||
]
|
||
else:
|
||
optim_groups = [
|
||
{
|
||
"params": [param_dict[n] for n in lr_1x],
|
||
"weight_decay": 0.0,
|
||
"my_lr_scale": 1.0,
|
||
}
|
||
]
|
||
|
||
if args.weight_decay > 0:
|
||
optim_groups += [
|
||
{
|
||
"params": [param_dict[n] for n in lr_decay],
|
||
"weight_decay": args.weight_decay,
|
||
"my_lr_scale": 1.0,
|
||
}
|
||
]
|
||
if self.deepspeed_offload:
|
||
return DeepSpeedCPUAdam(
|
||
optim_groups,
|
||
lr=self.args.lr_init,
|
||
betas=self.args.betas,
|
||
eps=self.args.adam_eps,
|
||
bias_correction=True,
|
||
adamw_mode=True,
|
||
amsgrad=False,
|
||
)
|
||
return FusedAdam(
|
||
optim_groups,
|
||
lr=self.args.lr_init,
|
||
betas=self.args.betas,
|
||
eps=self.args.adam_eps,
|
||
bias_correction=True,
|
||
adam_w_mode=True,
|
||
amsgrad=False,
|
||
)
|
||
else:
|
||
if self.deepspeed_offload:
|
||
return DeepSpeedCPUAdam(
|
||
optim_groups,
|
||
lr=self.args.lr_init,
|
||
betas=self.args.betas,
|
||
eps=self.args.adam_eps,
|
||
bias_correction=True,
|
||
adamw_mode=False,
|
||
weight_decay=0,
|
||
amsgrad=False,
|
||
)
|
||
return FusedAdam(
|
||
optim_groups,
|
||
lr=self.args.lr_init,
|
||
betas=self.args.betas,
|
||
eps=self.args.adam_eps,
|
||
bias_correction=True,
|
||
adam_w_mode=False,
|
||
weight_decay=0,
|
||
amsgrad=False,
|
||
)
|
||
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
|
||
|
||
@property
|
||
def deepspeed_offload(self) -> bool:
|
||
strategy = self.trainer.strategy
|
||
if isinstance(strategy, DeepSpeedStrategy):
|
||
cfg = strategy.config["zero_optimization"]
|
||
return cfg.get("offload_optimizer") or cfg.get("offload_param")
|
||
return False
|
||
|
||
if os.environ["RWKV_TRAIN_TYPE"] == "infctx":
|
||
|
||
def forward(
|
||
self, idx, last_shift_states: torch.Tensor, last_wkv_states: torch.Tensor
|
||
):
|
||
args = self.args
|
||
B, T = idx.size()
|
||
assert T <= args.chunk_ctx, "Cannot forward, model ctx_len is exhausted."
|
||
C = args.n_embd
|
||
H = args.dim_att // args.head_size_a
|
||
assert C == H * args.head_size_a
|
||
|
||
x = self.emb(idx)
|
||
x_emb = x
|
||
new_states = BlockStateList.empty(
|
||
args.n_layer, B, args.n_embd, H, x.device, x.dtype
|
||
)
|
||
if args.dropout > 0:
|
||
x = self.drop0(x)
|
||
|
||
for i, (block, block_state) in enumerate(
|
||
zip(self.blocks, BlockStateList(last_shift_states, last_wkv_states))
|
||
):
|
||
# x = x.to(block.device)
|
||
if args.grad_cp == 1 and i > 0: # and i < len(self.blocks)-1
|
||
x, new_block_state = torch_checkpoint(
|
||
block, x, block_state, use_reentrant=False
|
||
)
|
||
else:
|
||
x, new_block_state = block(x, block_state)
|
||
new_states[i] = new_block_state
|
||
|
||
x = self.ln_out(x)
|
||
|
||
if args.head_qk > 0:
|
||
q = self.head_q(x)[:, :T, :]
|
||
k = self.head_k(x)[:, :T, :]
|
||
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
|
||
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
|
||
|
||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||
c = c @ F.one_hot(idx, num_classes=args.vocab_size)
|
||
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||
c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
|
||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
|
||
|
||
x = self.head(x) + c
|
||
else:
|
||
x = self.head(x)
|
||
|
||
return x, new_states.shift_states, new_states.wkv_states
|
||
|
||
def training_step(self, batch, batch_idx):
|
||
args = self.args
|
||
T_train = args.chunk_ctx
|
||
idx, targets = batch
|
||
B, T = idx.shape
|
||
C = args.n_embd
|
||
H = args.dim_att // args.head_size_a
|
||
assert C == H * args.head_size_a
|
||
states = BlockStateList.create(
|
||
args.n_layer, B, C, H, idx.device, self.emb.weight.dtype
|
||
)
|
||
|
||
def checkpointed_step(
|
||
idx,
|
||
targets,
|
||
prev_loss,
|
||
last_shift_states,
|
||
last_wkv_states,
|
||
prev_token_amount,
|
||
):
|
||
logits, new_shift_states, new_wkv_states = self(
|
||
idx, last_shift_states, last_wkv_states
|
||
)
|
||
                current_token_amount = (targets != -100).sum()  # would this be more appropriate?
                current_token_amount = idx.shape[1]
|
||
if current_token_amount == 0:
|
||
loss = F.cross_entropy(
|
||
logits.view(-1, logits.size(-1)),
|
||
targets.reshape(-1),
|
||
reduction="sum",
|
||
)
|
||
else:
|
||
loss = F.cross_entropy(
|
||
logits.view(-1, logits.size(-1)), targets.reshape(-1)
|
||
)
|
||
loss = L2Wrap.apply(loss, logits, current_token_amount)
|
||
new_token_amount = prev_token_amount + current_token_amount
|
||
if new_token_amount > 0:
|
||
new_loss = prev_loss * (
|
||
prev_token_amount / new_token_amount
|
||
) + loss * (current_token_amount / new_token_amount)
|
||
else:
|
||
new_loss = prev_loss
|
||
|
||
return new_loss, new_shift_states, new_wkv_states, new_token_amount
|
||
|
||
total_loss = torch.tensor(0.0, dtype=self.emb.weight.dtype).requires_grad_()
|
||
token_amount = 0
|
||
i = 0
|
||
for i in range(math.ceil(T / T_train)):
|
||
# states.shift_states = states.shift_states.cuda()
|
||
# states.wkv_states = states.wkv_states.cuda()
|
||
total_loss, new_shift_states, new_wkv_states, token_amount = (
|
||
torch_checkpoint(
|
||
checkpointed_step,
|
||
idx[:, i * T_train : (i + 1) * T_train],
|
||
targets[:, i * T_train : (i + 1) * T_train],
|
||
total_loss,
|
||
states.shift_states,
|
||
states.wkv_states,
|
||
token_amount,
|
||
use_reentrant=False,
|
||
)
|
||
)
|
||
# total_loss,new_shift_states, new_wkv_states,token_amount = checkpointed_step(
|
||
# idx[:, i * T_train:(i + 1) * T_train],
|
||
# targets[:, i * T_train:(i + 1) * T_train],
|
||
# total_loss,
|
||
# states.shift_states,
|
||
# states.wkv_states,
|
||
# token_amount
|
||
# )
|
||
# new_shift_states = new_shift_states.cpu()
|
||
# new_wkv_states = new_wkv_states.cpu()
|
||
states = BlockStateList(new_shift_states, new_wkv_states)
|
||
|
||
return total_loss
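
        # training_step above implements chunked training for the infctx mode: the
        # sequence is split into chunks of args.chunk_ctx, each chunk is run under
        # torch_checkpoint (recomputed in backward to save memory), the recurrent
        # shift/wkv states are carried from one chunk to the next, and the returned
        # loss is the token-weighted mean of the per-chunk cross-entropy losses.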
|
||
|
||
else:
|
||
|
||
def forward(self, idx):
|
||
args = self.args
|
||
B, T = idx.size()
|
||
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
|
||
|
||
x = self.emb(idx)
|
||
x_emb = x
|
||
|
||
if args.dropout > 0:
|
||
x = self.drop0(x)
|
||
if args.tiny_att_dim > 0:
|
||
for block in self.blocks:
|
||
if args.grad_cp == 1:
|
||
if args.lora or args.state_tune or args.train_type == "state":
|
||
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
|
||
else:
|
||
x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
|
||
else:
|
||
x = block(x, x_emb)
|
||
else:
|
||
for block in self.blocks:
|
||
if args.grad_cp == 1:
|
||
if args.lora or args.state_tune or args.train_type == "state":
|
||
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
|
||
else:
|
||
x = deepspeed.checkpointing.checkpoint(block, x)
|
||
else:
|
||
x = block(x)
|
||
|
||
x = self.ln_out(x)
|
||
|
||
if args.head_qk > 0:
|
||
q = self.head_q(x)[:, :T, :]
|
||
k = self.head_k(x)[:, :T, :]
|
||
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
|
||
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
|
||
|
||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||
c = c @ F.one_hot(idx, num_classes=args.vocab_size)
|
||
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||
c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
|
||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
|
||
|
||
x = self.head(x) + c
|
||
else:
|
||
x = self.head(x)
|
||
|
||
return x
|
||
|
||
def training_step(self, batch, batch_idx):
|
||
args = self.args
|
||
if args.my_qa_mask != 1:
|
||
idx, targets = batch
|
||
logits = self(idx)
|
||
loss = F.cross_entropy(
|
||
logits.view(-1, logits.size(-1)), targets.view(-1)
|
||
)
|
||
# if '0' in os.environ["RWKV_MY_TESTING"]:
|
||
# print('logits', logits)
|
||
# torch.set_printoptions(threshold=10000)
|
||
# print('idx', idx)
|
||
# exit(0)
|
||
else:
|
||
idx, targets, mask = batch
|
||
mask = mask.view(-1)
|
||
sum_mask = torch.sum(mask).item()
|
||
# if sum_mask == 0:
|
||
# return torch.tensor([0.0], requires_grad=True)
|
||
|
||
logits = self(idx)
|
||
if sum_mask == mask.shape[0]:
|
||
loss = F.cross_entropy(
|
||
logits.view(-1, logits.size(-1)), targets.view(-1)
|
||
)
|
||
# print('rank', self.global_rank, 'loss', loss.item())
|
||
else:
|
||
loss = F.cross_entropy(
|
||
logits.view(-1, logits.size(-1)),
|
||
targets.view(-1),
|
||
reduction="none",
|
||
)
|
||
# loss_raw = loss
|
||
loss = torch.sum(loss * mask) / sum_mask
|
||
|
||
# torch.set_printoptions(threshold=10000)
|
||
# if True: #self.global_rank == 1:
|
||
# tmp = ''
|
||
# sss = 0
|
||
# ccc = 0
|
||
# for i in range(mask.shape[0]):
|
||
# if mask[i] > 0:
|
||
# tmp += str(idx.view(-1)[i].item()) + ','
|
||
# sss += loss_raw.view(-1)[i].float().item()
|
||
# ccc += 1
|
||
# print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
|
||
|
||
return L2Wrap.apply(loss, logits)
|
||
|
||
def training_step_end(self, batch_parts):
|
||
if pl.__version__[0] != "2":
|
||
all = self.all_gather(batch_parts)
|
||
if self.trainer.is_global_zero:
|
||
self.trainer.my_loss_all = all
|
||
|
||
def generate_init_weight(self):
|
||
print(
|
||
f"""
|
||
############################################################################
|
||
#
|
||
# Init model weight (slow for large models)...
|
||
#
|
||
############################################################################
|
||
"""
|
||
)
|
||
m = {}
|
||
for n in self.state_dict():
|
||
p = self.state_dict()[n]
|
||
shape = p.shape
|
||
|
||
gain = 1.0
|
||
scale = 1.0
|
||
if (
|
||
"ln_" in n
|
||
or ".ln" in n
|
||
or "time_" in n
|
||
or "_mask" in n
|
||
or "pos_emb" in n
|
||
or ".mask." in n
|
||
):
|
||
if "ln_x.weight" in n:
|
||
layer_scale = (1 + int(n.split(".")[1])) / self.args.n_layer
|
||
m[n] = (p * 0.0) + (layer_scale**0.7)
|
||
else:
|
||
m[n] = p
|
||
else:
|
||
if n == "emb.weight":
|
||
scale = -1 * self.args.lr_init
|
||
else:
|
||
if shape[0] > shape[1]:
|
||
gain = math.sqrt(shape[0] / shape[1])
|
||
|
||
zero = [
|
||
".att.output.",
|
||
".ffn.value.",
|
||
".ffn.receptance.",
|
||
".ffnPre.value.",
|
||
".ffnPre.receptance.",
|
||
"head_q.",
|
||
".oo.",
|
||
".rr.",
|
||
]
|
||
|
||
for kk in zero:
|
||
if kk in n:
|
||
scale = 0
|
||
if n == "head.weight":
|
||
scale = 0.5
|
||
if "head_k." in n:
|
||
scale = 0.1
|
||
if "head_q." in n:
|
||
scale = 0
|
||
|
||
print(
|
||
f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}"
|
||
)
|
||
|
||
if self.args.accelerator.upper() == "GPU":
|
||
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
|
||
else:
|
||
m[n] = torch.empty((shape[0], shape[1]))
|
||
|
||
if scale == 0:
|
||
nn.init.zeros_(m[n])
|
||
elif scale < 0:
|
||
nn.init.uniform_(m[n], a=scale, b=-scale)
|
||
else:
|
||
nn.init.orthogonal_(m[n], gain=gain * scale)
|
||
|
||
m[n] = m[n].cpu()
|
||
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||
m[n] = m[n].half()
|
||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||
m[n] = m[n].bfloat16()
|
||
|
||
# if n == "emb.weight":
|
||
# print(m[n])
|
||
|
||
gc.collect()
|
||
torch.cuda.empty_cache()
|
||
return m
|