Author: josc146
Date: 2024-05-28 22:35:47 +08:00
parent 3488d22d22
commit f05a4acb04
138 changed files with 29047 additions and 334 deletions

20
finetune/lora/v6/fla/modules/__init__.py vendored Normal file
View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
from fla.modules.convolution import (ImplicitLongConvolution, LongConvolution,
ShortConvolution)
from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss
from fla.modules.fused_norm_gate import (FusedLayerNormSwishGate,
FusedLayerNormSwishGateLinear,
FusedRMSNormSwishGate,
FusedRMSNormSwishGateLinear)
from fla.modules.layernorm import (LayerNorm, LayerNormLinear, RMSNorm,
RMSNormLinear)
from fla.modules.rotary import RotaryEmbedding
__all__ = [
'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
'FusedCrossEntropyLoss',
'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
'RotaryEmbedding'
]
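
As a quick orientation, the re-exports above expose the fused modules at the package level. A minimal, hedged usage sketch follows; the `RMSNorm(hidden_size)` signature is assumed from `fla.modules.layernorm`, and the fused kernels require a CUDA device with Triton available.

import torch
from fla.modules import RMSNorm  # re-exported above

norm = RMSNorm(512).cuda()                  # hidden_size assumed as the first argument
x = torch.randn(2, 16, 512, device='cuda')
y = norm(x)                                 # same shape as x: [2, 16, 512]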

394
finetune/lora/v6/fla/modules/activations.py vendored Normal file
View File

@@ -0,0 +1,394 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2024, Tri Dao, Yu Zhang, Songlin Yang.
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from fla.utils import contiguous
sigmoid_fwd_codestring = """
template <typename T> T sigmoid_fwd(T x) {
return 1.0f / (1.0f + ::exp(-float(x)));
}
"""
sigmoid_bwd_codestring = """
template <typename T> T sigmoid_bwd(T x, T g) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
return float(g) * x_sigmoid * (1.0f - x_sigmoid);
}
"""
sigmoid_fwd = torch.cuda.jiterator._create_jit_fn(sigmoid_fwd_codestring)
sigmoid_bwd = torch.cuda.jiterator._create_jit_fn(sigmoid_bwd_codestring)
class SigmoidFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return sigmoid_fwd(x)
@staticmethod
def backward(ctx, dout):
x, = ctx.saved_tensors
return sigmoid_bwd(x, dout)
sigmoid = SigmoidFunction.apply
@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
triton.Config({'BT': 16}, num_warps=4),
triton.Config({'BT': 16}, num_warps=8),
triton.Config({'BT': 32}, num_warps=2),
triton.Config({'BT': 32}, num_warps=4),
triton.Config({'BT': 32}, num_warps=8),
triton.Config({'BT': 64}, num_warps=2),
triton.Config({'BT': 64}, num_warps=4),
triton.Config({'BT': 64}, num_warps=8),
triton.Config({'BT': 128}, num_warps=2),
triton.Config({'BT': 128}, num_warps=4),
triton.Config({'BT': 128}, num_warps=8),
triton.Config({'BT': 256}, num_warps=2),
triton.Config({'BT': 256}, num_warps=4),
triton.Config({'BT': 256}, num_warps=8)
],
key=['D']
)
@triton.jit
def logsigmoid_fwd_kernel(
x,
y,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr
):
i = tl.program_id(0)
o_i = i * BT + tl.arange(0, BT)
p_x = x + o_i
p_y = y + o_i
mask = o_i < T
# [D,]
b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)
b_m = tl.minimum(0., b_x)
b_z = 1. + tl.exp(-tl.abs(b_x))
b_y = b_m - tl.log(b_z)
tl.store(p_y, b_y.to(p_y.dtype.element_ty), mask=mask)
@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
triton.Config({'BT': 16}, num_warps=4),
triton.Config({'BT': 16}, num_warps=8),
triton.Config({'BT': 32}, num_warps=2),
triton.Config({'BT': 32}, num_warps=4),
triton.Config({'BT': 32}, num_warps=8),
triton.Config({'BT': 64}, num_warps=2),
triton.Config({'BT': 64}, num_warps=4),
triton.Config({'BT': 64}, num_warps=8),
triton.Config({'BT': 128}, num_warps=2),
triton.Config({'BT': 128}, num_warps=4),
triton.Config({'BT': 128}, num_warps=8),
triton.Config({'BT': 256}, num_warps=2),
triton.Config({'BT': 256}, num_warps=4),
triton.Config({'BT': 256}, num_warps=8)
],
key=['D']
)
@triton.jit
def logsigmoid_bwd_kernel(
x,
dx,
dy,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr
):
i = tl.program_id(0)
o_i = i * BT + tl.arange(0, BT)
p_x = x + o_i
p_dx = dx + o_i
p_dy = dy + o_i
mask = o_i < T
# [D,]
b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)
b_dy = tl.load(p_dy, mask=mask, other=0.).to(tl.float32)
b_dx = b_dy * (1. - tl.sigmoid(b_x))
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
class LogSigmoidFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, x):
T, D = x.numel(), x.shape[-1]
y = torch.empty_like(x)
logsigmoid_fwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, y, T=T, D=D)
ctx.save_for_backward(x,)
return y
@staticmethod
@contiguous
def backward(ctx, dy):
x, = ctx.saved_tensors
T, D = x.numel(), x.shape[-1]
dx = torch.empty_like(x)
logsigmoid_bwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, dx, dy, T=T, D=D)
return dx
logsigmoid = LogSigmoidFunction.apply
swish_fwd_codestring = """
template <typename T> T swish_fwd(T x) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
return float(x) * x_sigmoid;
}
"""
swish_bwd_codestring = """
template <typename T> T swish_bwd(T x, T g) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
return float(g) * x_sigmoid * (1.0f - float(x) * x_sigmoid + float(x));
}
"""
swish_fwd = torch.cuda.jiterator._create_jit_fn(swish_fwd_codestring)
swish_bwd = torch.cuda.jiterator._create_jit_fn(swish_bwd_codestring)
class SwishFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return swish_fwd(x)
@staticmethod
def backward(ctx, dout):
x, = ctx.saved_tensors
return swish_bwd(x, dout)
swish = SwishFunction.apply
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(y, bias):
x = bias + y
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_bwd(g, y, bias):
"""Assume that y has shape (B, D) and bias has shape (D)"""
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
grad_y = ff * g
return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(input, bias)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_bwd(grad_output, input, bias)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def gelu_fwd(x):
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def gelu_bwd(g, x):
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
return (ff * g).to(dtype=x.dtype)
class FastGeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input):
ctx.save_for_backward(input)
return gelu_fwd(input)
@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
tmp = gelu_bwd(grad_output, input)
return tmp
fast_gelu_impl = FastGeLUFunction.apply
@torch.jit.script
def relu_bwd(g, x):
return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
@torch.jit.script
def sqrelu_fwd(x):
r = F.relu(x)
return (r * r).to(dtype=x.dtype)
@torch.jit.script
def sqrelu_bwd(g, x):
return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
class SquaredReLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return sqrelu_fwd(input)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return sqrelu_bwd(grad_output, input)
sqrelu = SquaredReLUFunction.apply
swiglu_fwd_codestring = """
template <typename T> T swiglu_fwd(T x, T y) {
return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd_codestring = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
dy = float(x) * x_sigmoid * float(g);
}
"""
swiglu_bwd_with_output_codestring = """
template <typename T> T swiglu_bwd_with_output(T x, T y, T g, T& dx, T& dy, T& z) {
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
float x_swish = float(x) * x_sigmoid;
dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
dy = x_swish * float(g);
z = x_swish * float(y);
}
"""
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
swiglu_bwd_with_output = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_with_output_codestring, num_outputs=3)
class SwiGLUFunction(torch.autograd.Function):
r"""
Swish-Gated Linear Unit (SwiGLU) function.
.. math::
\text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y
"""
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return swiglu_fwd(x, y)
@staticmethod
def backward(ctx, dout):
x, y = ctx.saved_tensors
return swiglu_bwd(x, y, dout)
class SwiGLULinearFunction(torch.autograd.Function):
r"""
Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation.
.. math::
\text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b
    This simple wrapper discards the intermediate result of SwiGLU(x, y) to save memory.
"""
@staticmethod
def forward(ctx, x, y, weight, bias):
z = swiglu_fwd(x, y)
out = F.linear(z.to(weight.dtype), weight, bias)
        # We don't store z; it will be recomputed in the backward pass to save memory
ctx.save_for_backward(x, y, weight)
ctx.linear_bias_is_none = bias is None
return out
@staticmethod
def backward(ctx, dout, *args):
x, y, weight = ctx.saved_tensors
dout = dout.reshape(-1, dout.shape[-1])
dz = F.linear(dout, weight.t()).view_as(x)
dx, dy, z = swiglu_bwd_with_output(x, y, dz)
dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1]))
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
return dx, dy, dlinear_weight, dlinear_bias
swiglu = SwiGLUFunction.apply
swiglu_linear = SwiGLULinearFunction.apply
ACT2FN = {
'relu': F.relu,
'sigmoid': sigmoid,
'logsigmoid': logsigmoid,
'silu': swish,
'swish': swish,
'sqrelu': sqrelu,
'gelu': fast_gelu_impl,
'bias_gelu': bias_gelu_impl,
}
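
The fused `swiglu` above computes swish(x) * y with a single CUDA jiterator kernel. A hedged sanity-check sketch against a plain PyTorch reference (requires a CUDA device; the tolerances below are illustrative):

import torch
import torch.nn.functional as F
from fla.modules.activations import swiglu  # SwiGLUFunction.apply, defined above

x = torch.randn(4, 128, device='cuda', requires_grad=True)
y = torch.randn(4, 128, device='cuda', requires_grad=True)

out_fused = swiglu(x, y)   # fused swish(x) * y
out_ref = F.silu(x) * y    # eager reference
torch.testing.assert_close(out_fused, out_ref, rtol=1e-4, atol=1e-4)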

336
finetune/lora/v6/fla/modules/convolution.py vendored Normal file
View File

@@ -0,0 +1,336 @@
# -*- coding: utf-8 -*-
# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py
import math
import warnings
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from fla.modules.activations import ACT2FN
from fla.utils import checkpoint
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn = None
causal_conv1d_update = None
def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
seqlen = u.shape[-1]
fft_size = 2 * seqlen
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
if k_rev is not None:
k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
k_f = k_f + k_rev_f.conj()
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
if len(u.shape) > 3:
k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
out = y + u
if gelu:
out = F.gelu(out)
if dropout_mask is not None:
return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
else:
return out.to(dtype=u.dtype)
@checkpoint
def proj_then_conv1d(
x: torch.Tensor,
proj_weight: torch.Tensor,
conv1d_weight: torch.Tensor,
conv1d_bias: Optional[torch.Tensor] = None,
cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
# We do matmul and transpose BLH -> HBL at the same time
x = rearrange(proj_weight @ rearrange(x, "b l d -> d (b l)"), "d (b l) -> b d l", l=x.shape[-2])
if causal_conv1d_fn is None:
raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
if cache is None:
x = causal_conv1d_fn(
x=x,
weight=rearrange(conv1d_weight, "d 1 w -> d w"),
bias=conv1d_bias,
activation="silu",
).transpose(1, 2)
else:
assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
x = x.squeeze(-1)
x = causal_conv1d_update(
x=x,
weight=rearrange(conv1d_weight, "d 1 w -> d w"),
bias=conv1d_bias,
cache=cache,
activation="silu",
)
return x
class ShortConvolution(nn.Conv1d):
"""
Simple wrapper around `nn.Conv1d` that accepts dimension last.
"""
def __init__(
self,
hidden_size: int,
kernel_size: int,
bias: bool = False,
activation: Optional[str] = 'silu',
use_causal_conv: Optional[bool] = True
):
super().__init__(in_channels=hidden_size,
out_channels=hidden_size,
kernel_size=kernel_size,
groups=hidden_size,
bias=bias,
padding=kernel_size - 1)
self.hidden_size = hidden_size
self.activation = None
if activation is not None:
assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
self.activation = activation
if use_causal_conv:
if causal_conv1d_fn is None:
                warnings.warn("Please install `causal-conv1d` to use causal convolutions; falling back to `use_causal_conv=False`.")
use_causal_conv = False
self.use_causal_conv = use_causal_conv
def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0,) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1,) * len(self.dilation):
s += ', dilation={dilation}'
if self.output_padding != (0,) * len(self.output_padding):
s += ', output_padding={output_padding}'
if self.groups != 1:
s += ', groups={groups}'
if self.bias is None:
s += ', bias=False'
if self.padding_mode != 'zeros':
s += ', padding_mode={padding_mode}'
if self.activation is not None:
s += ', activation={activation}'
if not self.use_causal_conv:
s += ', use_causal_conv={use_causal_conv}'
return s.format(**self.__dict__)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
x (`torch.Tensor`):
Tensor of shape `[batch_size, seq_len, hidden_size]`
mask (`Optional[torch.Tensor]`):
Attention mask dealing with padded positions.
cache (`Optional[torch.Tensor]`):
                Previous cache tensor of shape `[batch_size, hidden_size, kernel_size]`.
Returns:
Tensor of shape `[batch_size, seq_len, hidden_size]`. The `cache` (if provided) is updated inplace.
"""
if mask is not None:
x = x.mul_(mask.unsqueeze(-1))
if cache is not None and x.shape[1] == 1:
return self.step(x, cache)
x = rearrange(x, "b l d -> b d l")
# Update state (B D W)
if cache is not None:
cache.copy_(F.pad(x, (self.kernel_size[0] - x.shape[-1], 0)))
if self.use_causal_conv:
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.weight, "d 1 w -> d w"),
bias=self.bias,
activation=self.activation,
)
else:
x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
if self.activation is not None:
x = ACT2FN[self.activation](x)
return rearrange(x, "b d l -> b l d")
def step(
self,
x: torch.Tensor,
cache: torch.Tensor
):
assert x.shape[1] == 1, "Only support decoding with 1 token at a time for now"
x = x.squeeze(1)
if self.use_causal_conv:
x = causal_conv1d_update(
x=x,
conv_state=cache,
weight=rearrange(self.weight, "d 1 w -> d w"),
bias=self.bias,
activation=self.activation,
)
else:
dtype = x.dtype
cache.copy_(torch.roll(cache, shifts=-1, dims=-1))
cache[:, :, -1] = x
x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
if self.bias is not None:
x = x + self.bias
if self.activation is not None:
x = ACT2FN[self.activation](x).to(dtype=dtype)
return x.unsqueeze(1)
@property
def state_size(self) -> int:
return self.hidden_size * self.kernel_size
class LongConvolution(nn.Module):
"""
LongConvolution applies a convolution operation on the input tensor using a fixed
filter of length l_max.
The filter is learned during training and is applied using FFT convolution.
Args:
hidden_size (int): The number of expected features in the input and output.
l_max (int): The maximum sequence length.
Returns:
y: (b, l, d) tensor
"""
def __init__(
self,
hidden_size: int,
l_max: int,
**kwargs,
):
"""
Initializes the LongConvolution module.
Args:
hidden_size (int): The number of expected features in the input and output.
l_max (int): The maximum sequence length.
"""
super().__init__()
self.hidden_size = hidden_size
self.filter = nn.Parameter(torch.randn(self.hidden_size, l_max), requires_grad=True)
def forward(self, x: torch.Tensor, *args, **kwargs):
"""
Applies the LongConvolution operation on the input tensor.
Args:
x: (b, l, d) tensor
Returns:
y: (b, l, d) tensor
"""
x = x.transpose(1, 2)
y = fft_conv(x, self.filter, dropout_mask=None, gelu=False)
y = y.transpose(1, 2)
return y.to(dtype=x.dtype)
class PositionalEmbedding(nn.Module):
def __init__(self, emb_dim: int, seq_len: int, **kwargs):
"""Complex exponential positional embeddings for implicit long convolution filters."""
super().__init__()
self.seq_len = seq_len
        # The time embedding fed to the filters is normalized so that t_f = 1
t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
if emb_dim > 1:
bands = (emb_dim - 1) // 2
# To compute the right embeddings we use the "proper" linspace
t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1
f = torch.linspace(1e-4, bands - 1, bands)[None, None]
z = torch.exp(-1j * f * w)
z = torch.cat([t, z.real, z.imag], dim=-1)
self.z = nn.Parameter(z, requires_grad=False)
def forward(self, L):
return self.z[:, :L]
class ImplicitLongConvolution(nn.Module):
"""
Long convolution with implicit filter parameterized by an MLP.
Args:
hidden_size (int):
The number of expected features in the input and output.
l_max (int):
The maximum sequence length.
d_emb (Optional[int]):
            The dimension of the positional embeddings. Must be odd and greater than or equal to 3 (time, sine and cosine).
Defaults to 3.
d_hidden (Optional[int]):
The number of features in the hidden layer of the MLP. Defaults to 16.
Attributes:
pos_emb (`PositionalEmbedding`): The positional embedding layer.
mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter.
"""
def __init__(
self,
hidden_size: int,
l_max: int,
d_emb: int = 3,
d_hidden: int = 16,
**kwargs,
):
"""
Long convolution with implicit filter parameterized by an MLP.
"""
super().__init__()
self.hidden_size = hidden_size
self.d_emb = d_emb
assert (
d_emb % 2 != 0 and d_emb >= 3
), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)"
self.pos_emb = PositionalEmbedding(d_emb, l_max)
# final linear layer
self.mlp = nn.Sequential(
nn.Linear(d_emb, d_hidden),
torch.nn.ReLU(),
nn.Linear(d_hidden, hidden_size),
)
def filter(self, seq_len: int, *args, **kwargs):
k = self.mlp(self.pos_emb(seq_len))
return k.transpose(1, 2)
def forward(self, x: torch.Tensor, *args, **kwargs):
"""
Args:
x: (b, l, d) tensor
Returns:
y: (b, l, d) tensor
"""
x = x.transpose(1, 2)
k = self.filter(x.shape[-1])
y = fft_conv(x, k, dropout_mask=None, gelu=False)
y = y.transpose(1, 2)
return y.to(dtype=x.dtype)
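
A hedged usage sketch for `ShortConvolution`: dimension-last input and optional single-token decoding against a rolling cache of shape `[batch_size, hidden_size, kernel_size]`. If `causal-conv1d` is not installed, the module warns and falls back to the plain `nn.Conv1d` path; the shapes below are illustrative.

import torch
from fla.modules import ShortConvolution  # defined above

conv = ShortConvolution(hidden_size=256, kernel_size=4, activation='silu').cuda()
x = torch.randn(2, 32, 256, device='cuda')
y = conv(x)                                  # [2, 32, 256]

# decoding: one token at a time, cache updated in place
cache = torch.zeros(2, 256, 4, device='cuda')
step = conv(x[:, -1:], cache=cache)          # [2, 1, 256]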

235
finetune/lora/v6/fla/modules/feature_map.py vendored Normal file
View File

@@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from fla.modules.layernorm import layer_norm_fn
from fla.utils import checkpoint
@checkpoint
def flatten_diag_outer_product(x, y):
z = torch.einsum("...i,...j->...ij", x, y)
N = z.size(-1)
    indices = torch.triu_indices(N, N)
    return z[..., indices[0], indices[1]]
@checkpoint
def flatten_diag_outer_product_off1(x, y):
z = torch.einsum("...i,...j->...ij", x, y)
N = z.size(-1)
    indices = torch.triu_indices(N, N, 1)
    indices2 = torch.arange(0, N)
    return z[..., indices[0], indices[1]], z[..., indices2, indices2]
def is_power_of_2(n):
return (n & (n - 1) == 0) and n != 0
class HedgehogFeatureMap(nn.Module):
r"""
Hedgehog feature map as introduced in
`The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry <https://arxiv.org/abs/2402.04347>`_
"""
def __init__(
self,
head_dim: int
) -> HedgehogFeatureMap:
super().__init__()
# Trainable map
self.layer = nn.Linear(head_dim, head_dim)
self.init_weights_()
def init_weights_(self):
"""Initialize trainable map as identity"""
with torch.no_grad():
identity = torch.eye(*self.layer.weight.shape[-2:], dtype=torch.float)
self.layer.weight.copy_(identity.to(self.layer.weight))
nn.init.zeros_(self.layer.bias)
def forward(self, x: torch.Tensor):
x = self.layer(x) # shape b, h, l, d
return torch.cat([2*x, -2*x], dim=-1).softmax(-1)
class T2RFeatureMap(nn.Module):
r"""
Simple linear mapping feature map as in
`Finetuning Pretrained Transformers into RNNs <https://arxiv.org/abs/2103.13076>`_
"""
def __init__(
self,
head_dim: int,
dot_dim: int = None
) -> T2RFeatureMap:
super().__init__()
# Trainable map
if dot_dim is None:
dot_dim = head_dim
self.layer = nn.Linear(head_dim, dot_dim)
def forward(self, x: torch.Tensor):
return self.layer(x).relu()
class DPFPFeatureMap(nn.Module):
r"""
Deterministic Parameter-Free Projection (DPFP) feature map in
`Linear Transformers Are Secretly Fast Weight Programmers <https://arxiv.org/abs/2102.11174>`_
"""
def __init__(
self,
head_dim: int,
nu: int = 4
) -> DPFPFeatureMap:
super().__init__()
self.nu = nu
def forward(self, x: torch.Tensor):
x = torch.cat([x.relu(), -x.relu()], dim=-1)
x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu+1)], dim=-1)
x_repeat = torch.cat([x] * self.nu, dim=-1)
return x_repeat * x_rolled
class HadamardFeatureMap(nn.Module):
def __init__(
self,
head_dim: int
) -> HadamardFeatureMap:
super().__init__()
# Trainable map
self.layer1 = nn.Linear(head_dim, head_dim)
self.layer2 = nn.Linear(head_dim, head_dim)
def forward(self, x: torch.Tensor):
return self.layer1(x) * self.layer2(x)
class LearnableOuterProductFeatureMap(nn.Module):
def __init__(
self,
head_dim: int,
feature_dim: int
) -> LearnableOuterProductFeatureMap:
super().__init__()
# Trainable map
self.layer1 = nn.Linear(head_dim, feature_dim, bias=False)
self.layer2 = nn.Linear(head_dim, feature_dim, bias=False)
self.normalizer = feature_dim ** -0.5
def forward(self, x: torch.Tensor):
return flatten_diag_outer_product(self.layer1(x), self.layer2(x))
class LearnablePolySketchNonNegativeFeatureMap(nn.Module):
def __init__(
self,
head_dim: int,
sketch_size: Optional[int] = None,
degree: Optional[int] = 2
) -> LearnablePolySketchNonNegativeFeatureMap:
super().__init__()
assert is_power_of_2(degree) and degree >= 2, f"The degree {degree} must be a power of 2"
self.head_dim = head_dim
self.sketch_size = sketch_size if sketch_size is not None else head_dim
self.degree = degree
self.gamma = nn.Parameter(torch.ones(head_dim))
self.beta = nn.Parameter(torch.zeros(head_dim))
# NOTE: the sketch layers defined here are quite different from the original paper
# currently we simply use linear layers without any non-linear activations
self.sketches1 = nn.ModuleList([
nn.Linear(head_dim, sketch_size, bias=False),
*[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
])
self.sketches2 = nn.ModuleList([
nn.Linear(head_dim, sketch_size, bias=False),
*[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
])
def forward(self, x: torch.Tensor):
# Section 2.1
x = layer_norm_fn(x, self.gamma, self.beta)
# first map the input to sketch size with learnable parameters
x = self.sketches1[0](x) * self.sketches2[0](x) * self.head_dim ** -0.5
for i in range(1, int(math.log2(self.degree)) - 1):
x = self.sketches1[i](x) * self.sketches2[i](x) * self.head_dim ** -0.5
# do sketch mapping for log2(p) - 1 times in total
# do p=2 mapping to ensure non-negativity
return flatten_diag_outer_product(x, x)
class TaylorFeatureMap(nn.Module):
def __init__(
self,
head_dim: int
) -> TaylorFeatureMap:
super().__init__()
self.head_dim = head_dim
self.r2 = math.sqrt(2)
self.rd = math.sqrt(self.head_dim)
self.rrd = math.sqrt(self.rd)
def forward(self, x: torch.Tensor):
x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
return torch.cat([torch.ones_like(x[..., 0:1]), x / self.rrd, x2_2 / (self.rd * self.r2), x2_1 / self.rd], dim=-1)
class RebasedFeatureMap(nn.Module):
def __init__(
self,
head_dim: int,
use_gamma: Optional[bool] = True,
use_beta: Optional[bool] = True,
normalize: Optional[bool] = True
) -> RebasedFeatureMap:
super().__init__()
self.head_dim = head_dim
self.use_gamma = use_gamma
self.use_beta = use_beta
self.normalize = normalize
self.gamma = None
self.beta = None
if use_gamma:
self.gamma = nn.Parameter(torch.ones(head_dim))
if use_beta:
self.beta = nn.Parameter(torch.zeros(head_dim))
def forward(self, x: torch.Tensor, flatten: Optional[bool] = True):
if self.use_beta and self.use_gamma and self.normalize:
x = layer_norm_fn(x, self.gamma, self.beta)
elif self.normalize:
x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta)
elif self.use_gamma and self.use_beta:
x = torch.addcmul(self.beta, x, self.gamma)
elif self.use_gamma:
x = x.mul(self.gamma)
else:
            raise RuntimeError(f"Unsupported combination of `use_gamma`, `use_beta` and `normalize`, "
                               f"which is currently set as (`{self.use_gamma}`, `{self.use_beta}`, `{self.normalize}`)")
if not flatten:
return x
x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
# rebased use learnable parameters to approximate any quadratic function
return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1)
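
A hedged sketch of applying one of the feature maps above to attention queries. `HedgehogFeatureMap` is pure PyTorch, so no Triton kernel is involved; the module path `fla.modules.feature_map` is an assumption, since this file is not re-exported in the package `__init__` shown above.

import torch
from fla.modules.feature_map import HedgehogFeatureMap  # module path assumed

head_dim = 64
phi = HedgehogFeatureMap(head_dim)
q = torch.randn(2, 8, 128, head_dim)   # [batch, heads, seq_len, head_dim]
q_feat = phi(q)                        # [..., 2 * head_dim]; softmax output, rows sum to 1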

398
finetune/lora/v6/fla/modules/fused_cross_entropy.py vendored Normal file
View File

@@ -0,0 +1,398 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Tri Dao.
from typing import Tuple
import torch
import torch.nn as nn
import triton
import triton.language as tl
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 2 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
@triton.heuristics(
{
"HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
}
)
@triton.jit
def cross_entropy_fwd_kernel(
loss_ptr, # data ptrs
lse_ptr,
z_loss_ptr,
logits_ptr,
labels_ptr,
smoothing,
logit_scale,
lse_square_scale,
ignored_index,
total_classes,
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
n_cols, # shapes
n_rows,
logits_row_stride, # strides
BLOCK_SIZE: tl.constexpr,
HAS_SMOOTHING: tl.constexpr,
# if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
SPLIT: tl.constexpr,
):
row_idx = tl.program_id(0)
col_block_idx = tl.program_id(1)
logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
label_idx = tl.load(labels_ptr + row_idx)
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
tl.float32
) * logit_scale
max_logits = tl.max(logits, 0)
if HAS_SMOOTHING:
sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
if label_idx == ignored_index:
loss = 0.0
z_loss = 0.0
else:
label_idx -= class_start_idx
if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
n_cols, (col_block_idx + 1) * BLOCK_SIZE
):
logits_label = tl.load(logits_ptr + label_idx) * logit_scale
if HAS_SMOOTHING:
loss = (
(lse if not SPLIT else 0.0)
- smoothing * sum_logits / total_classes
- (1 - smoothing) * logits_label
)
else:
loss = (lse if not SPLIT else 0.0) - logits_label
else:
# If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
if HAS_SMOOTHING:
loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
else:
loss = 0.0
if not SPLIT:
z_loss = lse_square_scale * lse * lse
loss += z_loss
else:
z_loss = 0.0
tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
if not SPLIT:
tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
@triton.heuristics(
{
"HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
}
)
@triton.jit
def cross_entropy_bwd_kernel(
dlogits_ptr, # data ptrs
dloss_ptr,
logits_ptr,
lse_ptr,
labels_ptr,
smoothing,
logit_scale,
lse_square_scale,
ignored_index,
total_classes,
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
n_cols, # shapes
logits_row_stride, # strides
dlogits_row_stride,
dloss_row_stride,
BLOCK_SIZE: tl.constexpr,
HAS_SMOOTHING: tl.constexpr,
):
row_idx = tl.program_id(0)
col_block_idx = tl.program_id(1)
logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
label_idx = tl.load(labels_ptr + row_idx)
if label_idx != ignored_index:
dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
else:
dloss = 0.0
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
tl.float32
) * logit_scale
lse = tl.load(lse_ptr + row_idx)
probs = tl.exp(logits - lse)
probs += 2.0 * lse_square_scale * lse * probs
label_idx -= class_start_idx
if HAS_SMOOTHING:
smooth_negative = smoothing / total_classes
probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
else:
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
class CrossEntropyLossFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
logits,
labels,
smoothing=0.0,
logit_scale=1.0,
lse_square_scale=0.0,
ignored_index=-100,
inplace_backward=False,
process_group=None,
):
n_rows, n_cols = logits.shape
assert labels.shape == (n_rows,)
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
total_classes = world_size * n_cols
rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
class_start_idx = rank * n_cols
if logits.stride(-1) != 1:
logits = logits.contiguous()
# Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
MAX_BLOCK_SIZE = 64 * 1024
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
num_warps = (
4
if BLOCK_SIZE < 2048
else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
)
# We may split the lse computation across multiple blocks, then do a reduction
# lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
# where having just one thread block processing more than 64k elements is slow.
split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(logits.device.index):
cross_entropy_fwd_kernel[(n_rows, n_splits)](
losses, # data ptrs
lse,
z_losses,
logits,
labels,
smoothing,
logit_scale,
lse_square_scale,
ignored_index,
total_classes,
class_start_idx,
n_cols, # shapes
n_rows,
logits.stride(0), # strides
BLOCK_SIZE=BLOCK_SIZE, # constants
num_warps=num_warps,
SPLIT=split,
)
if split:
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
# - predicted logit, and 0 otherwise.
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
# -0.9 * predicted logit - 0.1 * sum logit / total_classes.
# For labels not in the vocab of this partition, losses contains
# -0.1 * sum logit / total_classes.
if n_splits > 1:
lse = torch.logsumexp(lse, dim=0)
losses = losses.sum(dim=0)
if world_size > 1:
lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
handle_losses = torch.distributed.all_reduce(
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
)
lse = torch.logsumexp(lse_allgather, dim=0)
handle_losses.wait()
# After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
# we just have to add the (global) lse.
# If there's smoothing=0.1, the total losses are
# -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
# Again, we just have to add the (global) lse.
losses += lse
if lse_square_scale != 0.0:
z_losses = lse_square_scale * lse.square()
z_losses.masked_fill_(labels == ignored_index, 0.0)
losses += z_losses
else:
z_losses = torch.zeros_like(losses)
losses.masked_fill_(labels == ignored_index, 0.0)
ctx.save_for_backward(logits, lse, labels)
ctx.mark_non_differentiable(z_losses)
ctx.smoothing = smoothing
ctx.logit_scale = logit_scale
ctx.lse_square_scale = lse_square_scale
ctx.ignored_index = ignored_index
ctx.total_classes = total_classes
ctx.class_start_idx = class_start_idx
ctx.inplace_backward = inplace_backward
return losses, z_losses
@staticmethod
def backward(ctx, grad_losses, grad_z_losses):
del grad_z_losses # z_losses are only for logging.
logits, lse, labels = ctx.saved_tensors
dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
n_rows, n_cols = logits.shape
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
def grid(META): return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(logits.device.index):
cross_entropy_bwd_kernel[grid](
dlogits, # data ptrs
grad_losses,
logits,
lse,
labels,
ctx.smoothing,
ctx.logit_scale,
ctx.lse_square_scale,
ctx.ignored_index,
ctx.total_classes,
ctx.class_start_idx,
n_cols, # shapes
logits.stride(0), # strides
dlogits.stride(0),
grad_losses.stride(0),
BLOCK_SIZE=BLOCK_SIZE, # constants
num_warps=num_warps,
)
return dlogits, None, None, None, None, None, None, None, None
def cross_entropy_loss(
logits: torch.Tensor,
labels: torch.Tensor,
label_smoothing: float = 0.0,
logit_scale: float = 1.0,
lse_square_scale: float = 0.0,
ignored_index=-100,
inplace_backward: bool = False,
process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
logits: (batch, vocab_size)
labels: (batch,)
label_smoothing: float
logit_scale: float. Multiply logits by this scale before calculating the loss.
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
This saves memory.
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
one part of the vocab. The loss will be aggregated across processes.
Returns:
losses: (batch,), float
z_losses: (batch,), float
"""
return CrossEntropyLossFunction.apply(
logits,
labels,
label_smoothing,
logit_scale,
lse_square_scale,
ignored_index,
inplace_backward,
process_group,
)
class FusedCrossEntropyLoss(nn.Module):
def __init__(
self,
ignore_index=-100,
reduction="mean",
label_smoothing=0.0,
logit_scale=1.0,
lse_square_scale=0.0,
inplace_backward=False,
process_group=None,
return_z_loss=False,
):
"""
Arguments:
            ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
label_smoothing: float
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
This saves memory.
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
one part of the vocab. The loss will be aggregated across processes.
return_z_loss: bool. If True, we return the component of the loss contributed by
the lse_square_scale value. This value is only for logging and does not support
backprop.
"""
super().__init__()
if reduction not in ["mean", "none", "sum"]:
raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
self.logit_scale = logit_scale
self.lse_square_scale = lse_square_scale
self.inplace_backward = inplace_backward
self.process_group = process_group
self.return_z_loss = return_z_loss
def forward(self, input, target):
"""
Arguments:
input: (batch, vocab_size)
target: (batch,)
Returns:
losses: (batch,) if reduction is 'none', else (1,), dtype float
z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
"""
assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
loss, z_loss = cross_entropy_loss(
input,
target,
label_smoothing=self.label_smoothing,
logit_scale=self.logit_scale,
lse_square_scale=self.lse_square_scale,
ignored_index=self.ignore_index,
inplace_backward=self.inplace_backward,
process_group=self.process_group,
)
if self.reduction == "mean":
loss = loss.sum() / (target != self.ignore_index).sum()
elif self.reduction == "sum":
loss = loss.sum()
else:
loss = loss
if not self.return_z_loss:
return loss
if self.reduction == "mean":
z_loss = z_loss.sum() / (target != self.ignore_index).sum()
elif self.reduction == "sum":
z_loss = z_loss.sum()
else:
z_loss = z_loss
return loss, z_loss
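
A hedged usage sketch of `FusedCrossEntropyLoss` as a CUDA-only drop-in for `nn.CrossEntropyLoss` (the forward pass asserts CUDA tensors; vocabulary size and batch size below are illustrative):

import torch
from fla.modules import FusedCrossEntropyLoss  # re-exported in the package __init__ above

criterion = FusedCrossEntropyLoss(ignore_index=-100, reduction='mean')
logits = torch.randn(8, 32000, device='cuda', requires_grad=True)   # (batch, vocab_size)
labels = torch.randint(0, 32000, (8,), device='cuda')               # (batch,)

loss = criterion(logits, labels)
loss.backward()   # gradients flow back through the Triton kernels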

889
finetune/lora/v6/fla/modules/fused_norm_gate.py vendored Normal file
View File

@@ -0,0 +1,889 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Tri Dao.
# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
# Implement residual + layer_norm / rms_norm.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from fla.utils import contiguous
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
dtype = x.dtype
if upcast:
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if residual is not None:
x = (x + residual).to(x.dtype)
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
dtype
)
return out if not prenorm else (out, x)
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
dtype = x.dtype
if upcast:
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if residual is not None:
x = (x + residual).to(x.dtype)
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + \
bias if bias is not None else (x * rstd * weight)
out = out.to(dtype)
return out if not prenorm else (out, x)
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
O, # pointer to the gate
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
RESIDUAL_OUT, # pointer to the residual
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_res_row,
stride_res_out_row,
N, # number of columns in X
eps, # epsilon to avoid division by zero
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
O += row * stride_x_row
if HAS_RESIDUAL:
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols <
N, other=0.0).to(tl.float32)
x += residual
if STORE_RESIDUAL_OUT:
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
y = x_hat * w if HAS_WEIGHT else x_hat
if HAS_BIAS:
y = y + b
# Swish output gate
o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)
y = y * o * tl.sigmoid(o)
# Write output
tl.store(Y + cols, y, mask=mask)
def _layer_norm_fwd(
x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
):
if residual is not None:
residual_dtype = residual.dtype
M, N = x.shape
assert x.stride(-1) == 1
if residual is not None:
assert residual.stride(-1) == 1
assert residual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
assert y.stride(-1) == 1
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
assert residual_out.stride(-1) == 1
else:
residual_out = None
mean = torch.empty((M,), dtype=torch.float32,
device="cuda") if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
with torch.cuda.device(x.device.index):
_layer_norm_fwd_1pass_kernel[(M,)](
x,
o,
y,
weight,
bias,
residual,
residual_out,
mean,
rstd,
x.stride(0),
y.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
N,
eps,
is_rms_norm,
BLOCK_N,
residual is not None,
residual_out is not None,
weight is not None,
bias is not None,
)
# residual_out is None if residual is None and residual_dtype == input_dtype
return y, mean, rstd, residual_out if residual_out is not None else x
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
X, # pointer to the input
O, # pointer to the gate
W, # pointer to the weights
B, # pointer to the biases
Y, # pointer to the output to be recomputed
DY, # pointer to the output gradient
DX, # pointer to the input gradient
DO, # pointer to the gate gradient
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
DRESIDUAL,
DRESIDUAL_IN,
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_dy_row,
stride_dx_row,
stride_dres_row,
stride_dres_in_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
rows_per_program,
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_DRESIDUAL: tl.constexpr,
STORE_DRESIDUAL: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
RECOMPUTE_OUTPUT: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
row_block_id = tl.program_id(0)
row_start = row_block_id * rows_per_program
cols = tl.arange(0, BLOCK_N)
mask = cols < N
X += row_start * stride_x_row
O += row_start * stride_x_row
if HAS_DRESIDUAL:
DRESIDUAL += row_start * stride_dres_row
if STORE_DRESIDUAL:
DRESIDUAL_IN += row_start * stride_dres_in_row
DY += row_start * stride_dy_row
DX += row_start * stride_dx_row
DO += row_start * stride_dx_row
if RECOMPUTE_OUTPUT:
Y += row_start * stride_y_row
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
if RECOMPUTE_OUTPUT and HAS_BIAS:
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
if HAS_BIAS:
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
row_end = min((row_block_id + 1) * rows_per_program, M)
for row in range(row_start, row_end):
# Load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
o = tl.load(O + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# Compute dx
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
xhat = tl.where(mask, xhat, 0.0)
y = xhat * w if HAS_WEIGHT else xhat
if HAS_BIAS:
y = y + b
if RECOMPUTE_OUTPUT:
tl.store(Y + cols, y, mask=mask)
sigmoid_o = tl.sigmoid(o)
do = dy * y * (sigmoid_o + o * sigmoid_o * (1 - sigmoid_o))
dy = dy * o * sigmoid_o
wdy = dy
if HAS_WEIGHT:
wdy = dy * w
dw += dy * xhat
if HAS_BIAS:
db += dy
if not IS_RMS_NORM:
c1 = tl.sum(xhat * wdy, axis=0) / N
c2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * c1 + c2)) * rstd
else:
c1 = tl.sum(xhat * wdy, axis=0) / N
dx = (wdy - xhat * c1) * rstd
if HAS_DRESIDUAL:
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
dx += dres
# Write dx
if STORE_DRESIDUAL:
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
tl.store(DX + cols, dx, mask=mask)
tl.store(DO + cols, do, mask=mask)
X += stride_x_row
O += stride_x_row
if HAS_DRESIDUAL:
DRESIDUAL += stride_dres_row
if STORE_DRESIDUAL:
DRESIDUAL_IN += stride_dres_in_row
if RECOMPUTE_OUTPUT:
Y += stride_y_row
DY += stride_dy_row
DX += stride_dx_row
DO += stride_dx_row
if HAS_WEIGHT:
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
if HAS_BIAS:
tl.store(DB + row_block_id * N + cols, db, mask=mask)
def _layer_norm_bwd(
dy,
x,
o,
weight,
bias,
eps,
mean,
rstd,
dresidual=None,
has_residual=False,
is_rms_norm=False,
x_dtype=None,
recompute_output=False,
):
M, N = x.shape
assert x.stride(-1) == 1
assert dy.stride(-1) == 1
assert dy.shape == (M, N)
if dresidual is not None:
assert dresidual.stride(-1) == 1
assert dresidual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
dx = (
torch.empty_like(x)
if x_dtype is None
else torch.empty(M, N, dtype=x_dtype, device=x.device)
)
do = (
torch.empty_like(o)
if x_dtype is None
else torch.empty(M, N, dtype=x_dtype, device=x.device)
)
dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
_dw = (
torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
if weight is not None
else None
)
_db = (
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
if bias is not None
else None
)
rows_per_program = math.ceil(M / sm_count)
grid = (sm_count,)
with torch.cuda.device(x.device.index):
_layer_norm_bwd_kernel[grid](
x,
o,
weight,
bias,
y,
dy,
dx,
do,
_dw,
_db,
dresidual,
dresidual_in,
mean,
rstd,
x.stride(0),
0 if not recompute_output else y.stride(0),
dy.stride(0),
dx.stride(0),
dresidual.stride(0) if dresidual is not None else 0,
dresidual_in.stride(0) if dresidual_in is not None else 0,
M,
N,
eps,
rows_per_program,
is_rms_norm,
BLOCK_N,
dresidual is not None,
dresidual_in is not None,
weight is not None,
bias is not None,
)
dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
db = _db.sum(0).to(bias.dtype) if bias is not None else None
# Don't need to compute dresidual_in separately in this case
if has_residual and dx.dtype == x.dtype:
dresidual_in = dx
return (dx, do, dw, db, dresidual_in) if not recompute_output else (dx, do, dw, db, dresidual_in, y)
class LayerNormSwishGateFn(torch.autograd.Function):
@staticmethod
@contiguous
def forward(
ctx,
x,
o,
weight,
bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
x_shape_og = x.shape
o_shape_og = o.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
o = o.reshape(-1, o.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
x, o, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
)
ctx.save_for_backward(residual_out, o, weight, bias, mean, rstd)
ctx.x_shape_og = x_shape_og
ctx.o_shape_og = o_shape_og
ctx.eps = eps
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
y = y.reshape(x_shape_og)
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
@staticmethod
@contiguous
def backward(ctx, dy, *args):
x, o, weight, bias, mean, rstd = ctx.saved_tensors
dy = dy.reshape(-1, dy.shape[-1])
assert dy.shape == x.shape
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
assert dresidual.shape == x.shape
else:
dresidual = None
dx, do, dw, db, dresidual_in = _layer_norm_bwd(
dy,
x,
o,
weight,
bias,
ctx.eps,
mean,
rstd,
dresidual,
ctx.has_residual,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
)
return (
dx.reshape(ctx.x_shape_og),
do.reshape(ctx.o_shape_og),
dw,
db,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
None,
None,
None,
None,
)
class LayerNormSwishGateLinearFn(torch.autograd.Function):
@staticmethod
@contiguous
def forward(
ctx,
x,
o,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
x_shape_og = x.shape
o_shape_og = o.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
o = o.reshape(-1, o.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
x,
o,
norm_weight,
norm_bias,
eps,
residual,
residual_dtype=residual_dtype,
is_rms_norm=is_rms_norm
)
y = y.reshape(x_shape_og)
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
linear_weight = linear_weight.to(dtype)
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y; it will be recomputed in the backward pass to save memory
ctx.save_for_backward(residual_out, o, norm_weight, norm_bias, linear_weight, mean, rstd)
ctx.x_shape_og = x_shape_og
ctx.o_shape_og = o_shape_og
ctx.eps = eps
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
ctx.linear_bias_is_none = linear_bias is None
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
@staticmethod
@contiguous
def backward(ctx, dout, *args):
x, o, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
dout = dout.reshape(-1, dout.shape[-1])
dy = F.linear(dout, linear_weight.t())
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
assert dy.shape == x.shape
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
assert dresidual.shape == x.shape
else:
dresidual = None
dx, do, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
dy,
x,
o,
norm_weight,
norm_bias,
ctx.eps,
mean,
rstd,
dresidual=dresidual,
has_residual=ctx.has_residual,
is_rms_norm=ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
recompute_output=True,
)
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
return (
dx.reshape(ctx.x_shape_og),
do.reshape(ctx.o_shape_og),
dnorm_weight,
dnorm_bias,
dlinear_weight,
dlinear_bias,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
None,
None,
None,
None,
)
def layer_norm_swish_gate_fn(
x,
o,
weight,
bias,
residual=None,
prenorm=False,
residual_in_fp32=False,
eps=1e-6
):
return LayerNormSwishGateFn.apply(
x,
o,
weight,
bias,
residual,
eps,
prenorm,
residual_in_fp32,
False
)
def rms_norm_swish_gate_fn(
x,
o,
weight,
bias,
residual=None,
prenorm=False,
residual_in_fp32=False,
eps=1e-6
):
return LayerNormSwishGateFn.apply(
x,
o,
weight,
bias,
residual,
eps,
prenorm,
residual_in_fp32,
True
)
def layer_norm_swish_gate_linear_fn(
x,
o,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
prenorm=False,
residual_in_fp32=False,
eps=1e-6
):
return LayerNormSwishGateLinearFn.apply(
x,
o,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual,
eps,
prenorm,
residual_in_fp32,
False
)
def rms_norm_swish_gate_linear_fn(
x,
o,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
prenorm=False,
residual_in_fp32=False,
eps=1e-6
):
return LayerNormSwishGateLinearFn.apply(
x,
o,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual,
eps,
prenorm,
residual_in_fp32,
True
)
class FusedLayerNormSwishGate(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> FusedLayerNormSwishGate:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
return layer_norm_swish_gate_fn(
x,
o,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32
)
class FusedRMSNormSwishGate(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> FusedRMSNormSwishGate:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
return rms_norm_swish_gate_fn(
x,
o,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32
)
class FusedLayerNormSwishGateLinear(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> FusedLayerNormSwishGateLinear:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
return layer_norm_swish_gate_linear_fn(
x,
o,
self.weight,
self.bias,
weight,
bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32
)
class FusedRMSNormSwishGateLinear(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> FusedRMSNormSwishGateLinear:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
return rms_norm_swish_gate_linear_fn(
x,
o,
self.weight,
self.bias,
weight,
bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32
)
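
# Illustrative usage sketch, not part of the original module: it shows how the fused
# gate classes are meant to be called, with the gate tensor `o` passed alongside the
# input `x` and the output gated by swish(o) = o * sigmoid(o). The shapes, the CUDA
# guard and the `__main__` wrapper are assumptions for the example only; the Triton
# kernels require a CUDA device.
if __name__ == "__main__":
    if torch.cuda.is_available():
        batch, seq_len, hidden = 2, 16, 64
        norm_gate = FusedRMSNormSwishGate(hidden).cuda()
        x = torch.randn(batch, seq_len, hidden, device="cuda", requires_grad=True)
        o = torch.randn(batch, seq_len, hidden, device="cuda", requires_grad=True)
        y = norm_gate(x, o)
        y.sum().backward()
        print(y.shape, x.grad.shape, o.grad.shape)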

finetune/lora/v6/fla/modules/l2norm.py vendored Normal file
@@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _l2_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
stride_x_row, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_N: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_x_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0)
rstd = 1 / tl.sqrt(var + eps)
# tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
y = x * rstd
# Write output
tl.store(Y + cols, y, mask=mask)
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _l2_norm_bwd_kernel(
X, # pointer to the input
# Y, # pointer to the output to be recomputed
DY, # pointer to the output gradient
DX, # pointer to the input gradient
stride_x_row, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_N: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
DX += row * stride_x_row
DY += row * stride_x_row
# Y += row * stride_y_row
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
x = tl.where(cols < N, x, 0.0)
var = tl.sum(x * x)
rstd = 1 / tl.sqrt(var + eps)
# tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
# y = x * rstd
dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32)
dy = tl.where(cols < N, dy, 0.0)
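    # Gradient of y = x * rstd with rstd = 1 / sqrt(sum(x^2) + eps):
    #   dx = dy * rstd - x * rstd^3 * sum(dy * x)
    # and rstd^3 = rstd / (var + eps), which is exactly the expression stored below.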
# dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x
dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x
tl.store(DX + cols, dx, mask=mask)
def _l2_norm_fwd(
x, eps=1e-6
):
x_shape_og = x.shape
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
M, N = x.shape
assert x.stride(-1) == 1
# allocate output
y = torch.empty_like(x)
assert y.stride(-1) == 1
N = x.shape[-1]
M = x.shape[0]
# rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError(
            "This l2 norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
with torch.cuda.device(x.device.index):
_l2_norm_fwd_1pass_kernel[(M,)](
x,
y,
x.stride(0),
N,
eps,
# is_rms_norm,
BLOCK_N,
# residual is not None,
# residual_out is not None,
# bias is not None,
)
return y.reshape(x_shape_og)
def _l2_norm_bwd(
x, dy, eps=1e-5,
):
x_shape_og = x.shape
x = x.reshape(-1, dy.shape[-1])
dy = dy.reshape(-1, dy.shape[-1])
if dy.stride(-1) != 1:
dy = dy.contiguous()
assert dy.shape == x.shape
# allocate output
dx = torch.empty_like(x)
N = x.shape[-1]
M = x.shape[0]
assert x.stride(-1) == 1
assert dy.stride(-1) == 1
# rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError(
            "This l2 norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
with torch.cuda.device(x.device.index):
_l2_norm_bwd_kernel[(M,)](
x,
dy,
dx,
x.stride(0),
N,
eps,
BLOCK_N,
)
return dx.reshape(x_shape_og)
class L2NormFN(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x,
eps=1e-6,
):
# reshape input data into 2D tensor
y = _l2_norm_fwd(x, eps)
        ctx.x_shape_og = x.shape
ctx.eps = eps
ctx.x_dtype = x.dtype
ctx.save_for_backward(x)
return y
@staticmethod
def backward(ctx, dy, *args):
x, = ctx.saved_tensors
dx = _l2_norm_bwd(
x,
dy,
ctx.eps,
)
return (
dx,
None
)
l2_norm_fn = L2NormFN.apply
if __name__ == '__main__':
x = torch.rand(10, 10, 100).cuda().requires_grad_(True)
y = torch.nn.functional.normalize(x, dim=-1, p=2)
dy = torch.rand_like(y)
y.backward(dy, retain_graph=True)
x_grad, x.grad = x.grad, None
y2 = l2_norm_fn(x, 1e-6)
print((y-y2).abs().max())
y2.backward(dy, retain_graph=True)
x_grad2, x.grad = x.grad, None
print((x_grad2-x_grad).abs().max())

finetune/lora/v6/fla/modules/layernorm.py vendored Normal file
@@ -0,0 +1,802 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Tri Dao.
# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
# Implement residual + layer_norm / rms_norm.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from fla.utils import contiguous
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
dtype = x.dtype
if upcast:
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if residual is not None:
x = (x + residual).to(x.dtype)
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
dtype
)
return out if not prenorm else (out, x)
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
dtype = x.dtype
if upcast:
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if residual is not None:
x = (x + residual).to(x.dtype)
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + \
bias if bias is not None else (x * rstd * weight)
out = out.to(dtype)
return out if not prenorm else (out, x)
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
RESIDUAL_OUT, # pointer to the residual
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_res_row,
stride_res_out_row,
N, # number of columns in X
eps, # epsilon to avoid division by zero
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_RESIDUAL:
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols <
N, other=0.0).to(tl.float32)
x += residual
if STORE_RESIDUAL_OUT:
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
y = x_hat * w if HAS_WEIGHT else x_hat
if HAS_BIAS:
y = y + b
# Write output
tl.store(Y + cols, y, mask=mask)
def _layer_norm_fwd(
x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
):
if residual is not None:
residual_dtype = residual.dtype
M, N = x.shape
assert x.stride(-1) == 1
if residual is not None:
assert residual.stride(-1) == 1
assert residual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
assert y.stride(-1) == 1
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
assert residual_out.stride(-1) == 1
else:
residual_out = None
mean = torch.empty((M,), dtype=torch.float32,
device="cuda") if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
with torch.cuda.device(x.device.index):
_layer_norm_fwd_1pass_kernel[(M,)](
x,
y,
weight,
bias,
residual,
residual_out,
mean,
rstd,
x.stride(0),
y.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
N,
eps,
is_rms_norm,
BLOCK_N,
residual is not None,
residual_out is not None,
weight is not None,
bias is not None,
)
# residual_out is None if residual is None and residual_dtype == input_dtype
return y, mean, rstd, residual_out if residual_out is not None else x
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
X, # pointer to the input
W, # pointer to the weights
B, # pointer to the biases
Y, # pointer to the output to be recomputed
DY, # pointer to the output gradient
DX, # pointer to the input gradient
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
DRESIDUAL,
DRESIDUAL_IN,
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_dy_row,
stride_dx_row,
stride_dres_row,
stride_dres_in_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
rows_per_program,
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_DRESIDUAL: tl.constexpr,
STORE_DRESIDUAL: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
RECOMPUTE_OUTPUT: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
row_block_id = tl.program_id(0)
row_start = row_block_id * rows_per_program
cols = tl.arange(0, BLOCK_N)
mask = cols < N
X += row_start * stride_x_row
if HAS_DRESIDUAL:
DRESIDUAL += row_start * stride_dres_row
if STORE_DRESIDUAL:
DRESIDUAL_IN += row_start * stride_dres_in_row
DY += row_start * stride_dy_row
DX += row_start * stride_dx_row
if RECOMPUTE_OUTPUT:
Y += row_start * stride_y_row
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
if RECOMPUTE_OUTPUT and HAS_BIAS:
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
if HAS_BIAS:
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
row_end = min((row_block_id + 1) * rows_per_program, M)
for row in range(row_start, row_end):
# Load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# Compute dx
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
xhat = tl.where(mask, xhat, 0.0)
if RECOMPUTE_OUTPUT:
y = xhat * w if HAS_WEIGHT else xhat
if HAS_BIAS:
y = y + b
tl.store(Y + cols, y, mask=mask)
wdy = dy
if HAS_WEIGHT:
wdy = dy * w
dw += dy * xhat
if HAS_BIAS:
db += dy
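        # With c1 = mean(xhat * wdy) and c2 = mean(wdy), the LayerNorm input gradient is
        #   dx = (wdy - xhat * c1 - c2) * rstd;
        # RMSNorm has no mean subtraction, so the c2 term drops out below.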
if not IS_RMS_NORM:
c1 = tl.sum(xhat * wdy, axis=0) / N
c2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * c1 + c2)) * rstd
else:
c1 = tl.sum(xhat * wdy, axis=0) / N
dx = (wdy - xhat * c1) * rstd
if HAS_DRESIDUAL:
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
dx += dres
# Write dx
if STORE_DRESIDUAL:
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
tl.store(DX + cols, dx, mask=mask)
X += stride_x_row
if HAS_DRESIDUAL:
DRESIDUAL += stride_dres_row
if STORE_DRESIDUAL:
DRESIDUAL_IN += stride_dres_in_row
if RECOMPUTE_OUTPUT:
Y += stride_y_row
DY += stride_dy_row
DX += stride_dx_row
if HAS_WEIGHT:
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
if HAS_BIAS:
tl.store(DB + row_block_id * N + cols, db, mask=mask)
def _layer_norm_bwd(
dy,
x,
weight,
bias,
eps,
mean,
rstd,
dresidual=None,
has_residual=False,
is_rms_norm=False,
x_dtype=None,
recompute_output=False,
):
M, N = x.shape
assert x.stride(-1) == 1
assert dy.stride(-1) == 1
assert dy.shape == (M, N)
if dresidual is not None:
assert dresidual.stride(-1) == 1
assert dresidual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
dx = (
torch.empty_like(x)
if x_dtype is None
else torch.empty(M, N, dtype=x_dtype, device=x.device)
)
dresidual_in = torch.empty_like(
x) if has_residual and dx.dtype != x.dtype else None
y = torch.empty(M, N, dtype=dy.dtype,
device=dy.device) if recompute_output else None
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
_dw = (
torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
if weight is not None
else None
)
_db = (
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
if bias is not None
else None
)
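    # _dw/_db hold one partial weight/bias gradient per program (one per SM);
    # they are reduced over the first dimension after the kernel to form dw/db.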
rows_per_program = math.ceil(M / sm_count)
grid = (sm_count,)
with torch.cuda.device(x.device.index):
_layer_norm_bwd_kernel[grid](
x,
weight,
bias,
y,
dy,
dx,
_dw,
_db,
dresidual,
dresidual_in,
mean,
rstd,
x.stride(0),
0 if not recompute_output else y.stride(0),
dy.stride(0),
dx.stride(0),
dresidual.stride(0) if dresidual is not None else 0,
dresidual_in.stride(0) if dresidual_in is not None else 0,
M,
N,
eps,
rows_per_program,
is_rms_norm,
BLOCK_N,
dresidual is not None,
dresidual_in is not None,
weight is not None,
bias is not None,
)
dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
db = _db.sum(0).to(bias.dtype) if bias is not None else None
# Don't need to compute dresidual_in separately in this case
if has_residual and dx.dtype == x.dtype:
dresidual_in = dx
return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
class LayerNormFn(torch.autograd.Function):
@staticmethod
@contiguous
def forward(
ctx,
x,
weight,
bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
)
ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
y = y.reshape(x_shape_og)
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
@staticmethod
@contiguous
def backward(ctx, dy, *args):
x, weight, bias, mean, rstd = ctx.saved_tensors
dy = dy.reshape(-1, dy.shape[-1])
assert dy.shape == x.shape
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
assert dresidual.shape == x.shape
else:
dresidual = None
dx, dw, db, dresidual_in = _layer_norm_bwd(
dy,
x,
weight,
bias,
ctx.eps,
mean,
rstd,
dresidual,
ctx.has_residual,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
)
return (
dx.reshape(ctx.x_shape_og),
dw,
db,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
None,
None,
None,
None,
)
def layer_norm_fn(
x,
weight,
bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
def rms_norm_fn(
x,
weight,
bias,
residual=None,
prenorm=False,
residual_in_fp32=False,
eps=1e-6
):
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
class LayerNorm(nn.Module):
def __init__(
self,
hidden_size: int,
elementwise_affine: bool = True,
eps: float = 1e-5
) -> LayerNorm:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
return layer_norm_fn(
x,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32
)
class RMSNorm(nn.Module):
def __init__(
self,
hidden_size: int,
elementwise_affine: bool = True,
eps: float = 1e-5
) -> RMSNorm:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
return rms_norm_fn(
x,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
)
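
# Illustrative sketch, not part of the original file: with `prenorm=True` the fused
# kernel returns both the normalized output and the updated residual stream, so a
# pre-norm block can keep the residual in fp32 while activations stay in low precision.
# The shapes and the helper name `_example_prenorm_block` are assumptions for the
# example only; a CUDA device is required by the Triton kernels.
def _example_prenorm_block():
    hidden = 64
    norm = RMSNorm(hidden).cuda()
    x = torch.randn(2, 8, hidden, device="cuda", dtype=torch.float16)
    # First block: no incoming residual; later blocks pass the returned one back in.
    out, residual = norm(x, residual=None, prenorm=True, residual_in_fp32=True)
    return out, residual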
class LayerNormLinearFn(torch.autograd.Function):
@staticmethod
@contiguous
def forward(
ctx,
x,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
x,
norm_weight,
norm_bias,
eps,
residual,
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
residual_dtype=residual_dtype,
is_rms_norm=is_rms_norm,
)
y = y.reshape(x_shape_og)
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
linear_weight = linear_weight.to(dtype)
linear_bias = linear_bias.to(
dtype) if linear_bias is not None else None
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
# We don't store y, will be recomputed in the backward pass to save memory
ctx.save_for_backward(residual_out, norm_weight,
norm_bias, linear_weight, mean, rstd)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
ctx.linear_bias_is_none = linear_bias is None
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
@staticmethod
@contiguous
def backward(ctx, dout, *args):
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
dout = dout.reshape(-1, dout.shape[-1])
dy = F.linear(dout, linear_weight.t())
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
assert dy.shape == x.shape
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
assert dresidual.shape == x.shape
else:
dresidual = None
dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
dy,
x,
norm_weight,
norm_bias,
ctx.eps,
mean,
rstd,
dresidual,
ctx.has_residual,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
recompute_output=True,
)
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
return (
dx.reshape(ctx.x_shape_og),
dnorm_weight,
dnorm_bias,
dlinear_weight,
dlinear_bias,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
None,
None,
None,
None,
)
def layer_norm_linear_fn(
x,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual=None,
eps=1e-6,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
):
return LayerNormLinearFn.apply(
x,
norm_weight,
norm_bias,
linear_weight,
linear_bias,
residual,
eps,
prenorm,
residual_in_fp32,
is_rms_norm,
)
class LayerNormLinear(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> LayerNormLinear:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
return layer_norm_linear_fn(
x,
self.weight,
self.bias,
weight,
bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=False
)
class RMSNormLinear(nn.Module):
def __init__(
self,
hidden_size,
elementwise_affine: bool = True,
eps=1e-5
) -> RMSNormLinear:
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(hidden_size))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.hidden_size}"
if not self.elementwise_affine:
s += f", elementwise_affine={self.elementwise_affine}"
s += f", eps={self.eps}"
s += ")"
return s
def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
return layer_norm_linear_fn(
x,
self.weight,
self.bias,
weight,
bias,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=True
)
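
# Illustrative sketch, not part of the original file: the *Linear variants fuse the
# normalization with a following projection, so the projection weight/bias are passed
# to forward() instead of being owned by the norm module. Shapes, names and the CUDA
# guard are assumptions for the example only.
if __name__ == "__main__":
    if torch.cuda.is_available():
        hidden, proj = 64, 128
        norm_linear = RMSNormLinear(hidden).cuda()
        linear = nn.Linear(hidden, proj).cuda()
        x = torch.randn(2, 8, hidden, device="cuda", requires_grad=True)
        y = norm_linear(x, linear.weight, linear.bias)
        y.sum().backward()
        print(y.shape)  # expected: torch.Size([2, 8, 128])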

finetune/lora/v6/fla/modules/rotary.py vendored Normal file
@@ -0,0 +1,310 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Tri Dao.
from typing import Optional, Tuple, Union
import torch
from einops import rearrange, repeat
from fla.ops.rotary import apply_rotary
def rotate_half(x, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
return torch.cat(
[x[..., :ro_dim] * cos +
rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
dim=-1,
)
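
# Illustrative sketch, not part of the original file: building the cos/sin tables and
# applying the pure PyTorch reference rotation above; it runs on CPU. The rotary_dim,
# base and helper name are assumptions for the example only.
def _example_rotary_reference():
    batch, seqlen, nheads, headdim, rotary_dim, base = 2, 16, 4, 64, 32, 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    t = torch.arange(seqlen).float()
    freqs = torch.outer(t, inv_freq)               # (seqlen, rotary_dim / 2)
    cos, sin = torch.cos(freqs), torch.sin(freqs)
    x = torch.randn(batch, seqlen, nheads, headdim)
    # only the first rotary_dim features are rotated; the rest pass through unchanged
    return apply_rotary_emb_torch(x, cos, sin)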
class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x,
cos,
sin,
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
):
out = apply_rotary(
x,
cos,
sin,
seqlen_offsets=seqlen_offsets,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
interleaved=interleaved,
inplace=inplace,
)
if isinstance(seqlen_offsets, int):
# Can't save int with save_for_backward
ctx.save_for_backward(cos, sin, cu_seqlens)
ctx.seqlen_offsets = seqlen_offsets
else:
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
ctx.seqlen_offsets = None
ctx.interleaved = interleaved
ctx.inplace = inplace
ctx.max_seqlen = max_seqlen
return out if not inplace else x
@staticmethod
def backward(ctx, do):
seqlen_offsets = ctx.seqlen_offsets
if seqlen_offsets is None:
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
else:
cos, sin, cu_seqlens = ctx.saved_tensors
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
if not ctx.interleaved and not ctx.inplace:
do = do.clone()
dx = apply_rotary(
do,
cos,
sin,
seqlen_offsets=seqlen_offsets,
cu_seqlens=cu_seqlens,
max_seqlen=ctx.max_seqlen,
interleaved=ctx.interleaved,
inplace=ctx.inplace,
conjugate=True,
)
return dx, None, None, None, None, None, None, None
def apply_rotary_emb(
x,
cos,
sin,
interleaved=False,
inplace=False,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
):
"""
Arguments:
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
cos, sin: (seqlen_rotary, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
inplace: if True, apply rotary embedding in-place.
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
Most commonly used in inference when we have KV cache.
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Return:
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim)
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
return ApplyRotaryEmb.apply(
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
)
# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb
class RotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et. al).
A crucial insight from the method is that the query and keys are
transformed by rotation matrices which depend on the relative positions.
    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration.
.. _RoFormer: https://arxiv.org/abs/2104.09864
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
"""
def __init__(
self,
dim: int,
base=10000.0,
interleaved=False,
scale_base=None,
pos_idx_in_fp32=True,
device=None,
):
"""
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
otherwise they might be in lower precision.
This option was added because previously (before 2023-07-02), when we construct
the position indices, we use the dtype of self.inv_freq. In most cases this would
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
self.inv_freq would be bf16, and the position indices are also in bf16.
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
embeddings for some positions will coincide.
To maintain compatibility with models previously trained in pure bf16,
we add this option.
"""
super().__init__()
self.dim = dim
self.base = float(base)
self.pos_idx_in_fp32 = pos_idx_in_fp32
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = self._compute_inv_freq(device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.interleaved = interleaved
self.scale_base = scale_base
scale = (
(torch.arange(0, dim, 2, device=device,
dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
if scale_base is not None
else None
)
self.register_buffer("scale", scale, persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
def _compute_inv_freq(self, device=None):
return 1.0 / (
self.base
** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
)
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
# if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())
):
self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
if self.pos_idx_in_fp32:
t = torch.arange(seqlen, device=device, dtype=torch.float32)
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
# will be large. Having it in bf16 will lose a lot of precision and cause the
# cos & sin output to change significantly.
# We want to recompute self.inv_freq if it was not loaded in fp32
if self.inv_freq.dtype != torch.float32:
inv_freq = self._compute_inv_freq(device=device)
else:
inv_freq = self.inv_freq
else:
t = torch.arange(seqlen, device=device,
dtype=self.inv_freq.dtype)
inv_freq = self.inv_freq
# Don't do einsum, it converts fp32 to fp16 under AMP
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, inv_freq)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = (
torch.arange(seqlen, dtype=self.scale.dtype,
device=self.scale.device)
- seqlen // 2
) / self.scale_base
scale = self.scale.to(
device=power.device) ** rearrange(power, "s -> s 1")
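                # XPos: queries use cos/sin scaled by `scale**power` while keys use the
                # inverse scaling, so their product decays with relative distance.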
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
seqlen_offset: Union[int, torch.Tensor] = 0,
max_seqlen: Optional[int] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in q and k is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding to q and k, returning the rotated tensors.
"""
seqlen = q.shape[1]
if max_seqlen is not None:
self._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype)
elif isinstance(seqlen_offset, int):
self._update_cos_sin_cache(seqlen + seqlen_offset, device=q.device, dtype=q.dtype)
if self.scale is None:
q = apply_rotary_emb_func(
q,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
k = apply_rotary_emb_func(
k,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
else:
q = apply_rotary_emb_func(
q,
self._cos_cached,
self._sin_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
k = apply_rotary_emb_func(
k,
self._cos_k_cached,
self._sin_k_cached,
interleaved=self.interleaved,
seqlen_offsets=seqlen_offset,
)
return q, k
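
# Illustrative sketch, not part of the original file: typical use of RotaryEmbedding in
# an attention layer, rotating q/k before the attention product. Head counts, dtype and
# the CUDA guard are assumptions for the example only; the underlying kernel needs CUDA.
if __name__ == "__main__":
    if torch.cuda.is_available():
        batch, seqlen, nheads, headdim = 2, 32, 4, 64
        rotary = RotaryEmbedding(dim=headdim, interleaved=False).cuda()
        q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
        k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
        q, k = rotary(q, k, seqlen_offset=0)
        print(q.shape, k.shape)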