finetune/lora/v6/fla/modules/__init__.py (vendored, new file)
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-

from fla.modules.convolution import (ImplicitLongConvolution, LongConvolution,
                                     ShortConvolution)
from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss
from fla.modules.fused_norm_gate import (FusedLayerNormSwishGate,
                                         FusedLayerNormSwishGateLinear,
                                         FusedRMSNormSwishGate,
                                         FusedRMSNormSwishGateLinear)
from fla.modules.layernorm import (LayerNorm, LayerNormLinear, RMSNorm,
                                   RMSNormLinear)
from fla.modules.rotary import RotaryEmbedding

__all__ = [
    'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
    'FusedCrossEntropyLoss',
    'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
    'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
    'RotaryEmbedding'
]
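The package __init__ above just re-exports the public classes, so downstream code can import them from fla.modules directly. A minimal usage sketch, assuming a CUDA build of PyTorch with Triton available and that RMSNorm takes the hidden size as its first constructor argument (RMSNorm lives in fla/modules/layernorm.py, which is not part of this diff):

    import torch
    from fla.modules import RMSNorm, ShortConvolution

    x = torch.randn(2, 16, 64, device='cuda')      # [batch, seq_len, hidden_size]
    norm = RMSNorm(64).cuda()
    conv = ShortConvolution(hidden_size=64, kernel_size=4).cuda()
    y = conv(norm(x))                              # dimension-last in, dimension-last out: [2, 16, 64]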
finetune/lora/v6/fla/modules/activations.py (vendored, new file)
@@ -0,0 +1,394 @@
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023-2024, Tri Dao, Yu Zhang, Songlin Yang.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
sigmoid_fwd_codestring = """
|
||||
template <typename T> T sigmoid_fwd(T x) {
|
||||
return 1.0f / (1.0f + ::exp(-float(x)));
|
||||
}
|
||||
"""
|
||||
sigmoid_bwd_codestring = """
|
||||
template <typename T> T sigmoid_bwd(T x, T g) {
|
||||
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
|
||||
return float(g) * x_sigmoid * (1.0f - x_sigmoid);
|
||||
}
|
||||
"""
|
||||
|
||||
sigmoid_fwd = torch.cuda.jiterator._create_jit_fn(sigmoid_fwd_codestring)
|
||||
sigmoid_bwd = torch.cuda.jiterator._create_jit_fn(sigmoid_bwd_codestring)
|
||||
|
||||
|
||||
class SigmoidFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return sigmoid_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
x, = ctx.saved_tensors
|
||||
return sigmoid_bwd(x, dout)
|
||||
|
||||
|
||||
sigmoid = SigmoidFunction.apply
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BT': 16}, num_warps=2),
|
||||
triton.Config({'BT': 16}, num_warps=4),
|
||||
triton.Config({'BT': 16}, num_warps=8),
|
||||
triton.Config({'BT': 32}, num_warps=2),
|
||||
triton.Config({'BT': 32}, num_warps=4),
|
||||
triton.Config({'BT': 32}, num_warps=8),
|
||||
triton.Config({'BT': 64}, num_warps=2),
|
||||
triton.Config({'BT': 64}, num_warps=4),
|
||||
triton.Config({'BT': 64}, num_warps=8),
|
||||
triton.Config({'BT': 128}, num_warps=2),
|
||||
triton.Config({'BT': 128}, num_warps=4),
|
||||
triton.Config({'BT': 128}, num_warps=8),
|
||||
triton.Config({'BT': 256}, num_warps=2),
|
||||
triton.Config({'BT': 256}, num_warps=4),
|
||||
triton.Config({'BT': 256}, num_warps=8)
|
||||
],
|
||||
key=['D']
|
||||
)
|
||||
@triton.jit
|
||||
def logsigmoid_fwd_kernel(
|
||||
x,
|
||||
y,
|
||||
T: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BT: tl.constexpr
|
||||
):
|
||||
i = tl.program_id(0)
|
||||
o_i = i * BT + tl.arange(0, BT)
|
||||
|
||||
p_x = x + o_i
|
||||
p_y = y + o_i
|
||||
mask = o_i < T
|
||||
|
||||
# [D,]
|
||||
b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)
|
||||
b_m = tl.minimum(0., b_x)
|
||||
b_z = 1. + tl.exp(-tl.abs(b_x))
|
||||
b_y = b_m - tl.log(b_z)
|
||||
tl.store(p_y, b_y.to(p_y.dtype.element_ty), mask=mask)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BT': 16}, num_warps=2),
|
||||
triton.Config({'BT': 16}, num_warps=4),
|
||||
triton.Config({'BT': 16}, num_warps=8),
|
||||
triton.Config({'BT': 32}, num_warps=2),
|
||||
triton.Config({'BT': 32}, num_warps=4),
|
||||
triton.Config({'BT': 32}, num_warps=8),
|
||||
triton.Config({'BT': 64}, num_warps=2),
|
||||
triton.Config({'BT': 64}, num_warps=4),
|
||||
triton.Config({'BT': 64}, num_warps=8),
|
||||
triton.Config({'BT': 128}, num_warps=2),
|
||||
triton.Config({'BT': 128}, num_warps=4),
|
||||
triton.Config({'BT': 128}, num_warps=8),
|
||||
triton.Config({'BT': 256}, num_warps=2),
|
||||
triton.Config({'BT': 256}, num_warps=4),
|
||||
triton.Config({'BT': 256}, num_warps=8)
|
||||
],
|
||||
key=['D']
|
||||
)
|
||||
@triton.jit
|
||||
def logsigmoid_bwd_kernel(
|
||||
x,
|
||||
dx,
|
||||
dy,
|
||||
T: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BT: tl.constexpr
|
||||
):
|
||||
i = tl.program_id(0)
|
||||
o_i = i * BT + tl.arange(0, BT)
|
||||
|
||||
p_x = x + o_i
|
||||
p_dx = dx + o_i
|
||||
p_dy = dy + o_i
|
||||
mask = o_i < T
|
||||
|
||||
# [D,]
|
||||
b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)
|
||||
b_dy = tl.load(p_dy, mask=mask, other=0.).to(tl.float32)
|
||||
b_dx = b_dy * (1. - tl.sigmoid(b_x))
|
||||
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
|
||||
|
||||
|
||||
class LogSigmoidFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(ctx, x):
|
||||
T, D = x.numel(), x.shape[-1]
|
||||
y = torch.empty_like(x)
|
||||
logsigmoid_fwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['BT']),)](x, y, T=T, D=D)
|
||||
ctx.save_for_backward(x,)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, dy):
|
||||
x, = ctx.saved_tensors
|
||||
T, D = x.numel(), x.shape[-1]
|
||||
dx = torch.empty_like(x)
|
||||
logsigmoid_bwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['BT']),)](x, dx, dy, T=T, D=D)
|
||||
return dx
|
||||
|
||||
|
||||
logsigmoid = LogSigmoidFunction.apply
|
||||
|
||||
swish_fwd_codestring = """
|
||||
template <typename T> T swish_fwd(T x) {
|
||||
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
|
||||
return float(x) * x_sigmoid;
|
||||
}
|
||||
"""
|
||||
swish_bwd_codestring = """
|
||||
template <typename T> T swish_bwd(T x, T g) {
|
||||
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
|
||||
return float(g) * x_sigmoid * (1.0f - float(x) * x_sigmoid + float(x));
|
||||
}
|
||||
"""
|
||||
|
||||
swish_fwd = torch.cuda.jiterator._create_jit_fn(swish_fwd_codestring)
|
||||
swish_bwd = torch.cuda.jiterator._create_jit_fn(swish_bwd_codestring)
|
||||
|
||||
|
||||
class SwishFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return swish_fwd(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
x, = ctx.saved_tensors
|
||||
return swish_bwd(x, dout)
|
||||
|
||||
|
||||
swish = SwishFunction.apply
|
||||
|
||||
# 1/sqrt(2*pi)-> 0.3989423
|
||||
# 1/sqrt(2) -> 0.70710678
|
||||
# sqrt(2/pi) -> 0.79788456
|
||||
|
||||
|
||||
# this function is tanh approximation of gelu
|
||||
# actual gelu is:
|
||||
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
|
||||
@torch.jit.script
|
||||
def bias_gelu(y, bias):
|
||||
x = bias + y
|
||||
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
|
||||
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@torch.jit.script
|
||||
def bias_gelu_bwd(g, y, bias):
|
||||
"""Assume that y has shape (B, D) and bias has shape (D)"""
|
||||
x = bias + y
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
|
||||
1 + tanh_out
|
||||
)
|
||||
grad_y = ff * g
|
||||
return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
|
||||
|
||||
|
||||
class GeLUFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
# bias is an optional argument
|
||||
def forward(ctx, input, bias):
|
||||
ctx.save_for_backward(input, bias)
|
||||
return bias_gelu(input, bias)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, bias = ctx.saved_tensors
|
||||
tmp = bias_gelu_bwd(grad_output, input, bias)
|
||||
return tmp
|
||||
|
||||
|
||||
bias_gelu_impl = GeLUFunction.apply
|
||||
|
||||
|
||||
# this function is tanh approximation of gelu
|
||||
# actual gelu is:
|
||||
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
|
||||
@torch.jit.script
|
||||
def gelu_fwd(x):
|
||||
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
|
||||
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@torch.jit.script
|
||||
def gelu_bwd(g, x):
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
|
||||
1 + tanh_out
|
||||
)
|
||||
return (ff * g).to(dtype=x.dtype)
|
||||
|
||||
|
||||
class FastGeLUFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
||||
def forward(ctx, input):
|
||||
ctx.save_for_backward(input)
|
||||
return gelu_fwd(input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
(input,) = ctx.saved_tensors
|
||||
tmp = gelu_bwd(grad_output, input)
|
||||
return tmp
|
||||
|
||||
|
||||
fast_gelu_impl = FastGeLUFunction.apply
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def relu_bwd(g, x):
|
||||
return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def sqrelu_fwd(x):
|
||||
r = F.relu(x)
|
||||
return (r * r).to(dtype=x.dtype)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def sqrelu_bwd(g, x):
|
||||
return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
|
||||
|
||||
|
||||
class SquaredReLUFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
ctx.save_for_backward(input)
|
||||
return sqrelu_fwd(input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
return sqrelu_bwd(grad_output, input)
|
||||
|
||||
|
||||
sqrelu = SquaredReLUFunction.apply
|
||||
|
||||
|
||||
swiglu_fwd_codestring = """
|
||||
template <typename T> T swiglu_fwd(T x, T y) {
|
||||
return float(x) * float(y) / (1.0f + ::exp(-float(x)));
|
||||
}
|
||||
"""
|
||||
swiglu_bwd_codestring = """
|
||||
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
|
||||
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
|
||||
dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
|
||||
dy = float(x) * x_sigmoid * float(g);
|
||||
}
|
||||
"""
|
||||
|
||||
swiglu_bwd_with_output_codestring = """
|
||||
template <typename T> T swiglu_bwd_with_output(T x, T y, T g, T& dx, T& dy, T& z) {
|
||||
float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
|
||||
float x_swish = float(x) * x_sigmoid;
|
||||
dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
|
||||
dy = x_swish * float(g);
|
||||
z = x_swish * float(y);
|
||||
}
|
||||
"""
|
||||
|
||||
swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
|
||||
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
|
||||
swiglu_bwd_with_output = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_with_output_codestring, num_outputs=3)
|
||||
|
||||
|
||||
class SwiGLUFunction(torch.autograd.Function):
|
||||
r"""
|
||||
Swish-Gated Linear Unit (SwiGLU) function.
|
||||
|
||||
.. math::
|
||||
\text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x, y)
|
||||
return swiglu_fwd(x, y)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
x, y = ctx.saved_tensors
|
||||
return swiglu_bwd(x, y, dout)
|
||||
|
||||
|
||||
class SwiGLULinearFunction(torch.autograd.Function):
|
||||
r"""
|
||||
Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation.
|
||||
|
||||
.. math::
|
||||
\text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b
|
||||
|
||||
This simple wrap discards the intermediate results of SwiGLU(x, y) to save memory.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, y, weight, bias):
|
||||
z = swiglu_fwd(x, y)
|
||||
out = F.linear(z.to(weight.dtype), weight, bias)
|
||||
# We don't store z, will be recomputed in the backward pass to save memory
|
||||
ctx.save_for_backward(x, y, weight)
|
||||
ctx.linear_bias_is_none = bias is None
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
x, y, weight = ctx.saved_tensors
|
||||
dout = dout.reshape(-1, dout.shape[-1])
|
||||
dz = F.linear(dout, weight.t()).view_as(x)
|
||||
dx, dy, z = swiglu_bwd_with_output(x, y, dz)
|
||||
dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1]))
|
||||
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
||||
return dx, dy, dlinear_weight, dlinear_bias
|
||||
|
||||
|
||||
swiglu = SwiGLUFunction.apply
|
||||
|
||||
swiglu_linear = SwiGLULinearFunction.apply
|
||||
|
||||
ACT2FN = {
|
||||
'relu': F.relu,
|
||||
'sigmoid': sigmoid,
|
||||
'logsigmoid': logsigmoid,
|
||||
'silu': swish,
|
||||
'swish': swish,
|
||||
'sqrelu': sqrelu,
|
||||
'gelu': fast_gelu_impl,
|
||||
'bias_gelu': bias_gelu_impl,
|
||||
}
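activations.py bundles jiterator-compiled sigmoid/swish/SwiGLU kernels, a Triton log-sigmoid and the tanh approximation of GeLU, and exposes them through the ACT2FN registry. A minimal sketch of how the public entry points are called (CUDA tensors only, since the kernels are built with torch.cuda.jiterator and Triton; the shapes are illustrative):

    import torch
    from fla.modules.activations import ACT2FN, logsigmoid, swiglu

    x = torch.randn(4, 128, device='cuda', requires_grad=True)
    g = torch.randn(4, 128, device='cuda', requires_grad=True)

    y1 = ACT2FN['swish'](x)    # fused swish/silu forward and backward
    y2 = logsigmoid(x)         # Triton log-sigmoid kernels
    y3 = swiglu(x, g)          # swish(x) * g with a fused backward
    (y1.sum() + y2.sum() + y3.sum()).backward()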
finetune/lora/v6/fla/modules/convolution.py (vendored, new file)
@@ -0,0 +1,336 @@
# -*- coding: utf-8 -*-
|
||||
|
||||
# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from fla.modules.activations import ACT2FN
|
||||
from fla.utils import checkpoint
|
||||
|
||||
try:
|
||||
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||
except ImportError:
|
||||
causal_conv1d_fn = None
|
||||
causal_conv1d_update = None
|
||||
|
||||
|
||||
def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
|
||||
seqlen = u.shape[-1]
|
||||
fft_size = 2 * seqlen
|
||||
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
|
||||
if k_rev is not None:
|
||||
k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
|
||||
k_f = k_f + k_rev_f.conj()
|
||||
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
|
||||
|
||||
if len(u.shape) > 3:
|
||||
k_f = k_f.unsqueeze(1)
|
||||
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
|
||||
|
||||
out = y + u
|
||||
if gelu:
|
||||
out = F.gelu(out)
|
||||
if dropout_mask is not None:
|
||||
return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
|
||||
else:
|
||||
return out.to(dtype=u.dtype)
|
||||
|
||||
|
||||
@checkpoint
|
||||
def proj_then_conv1d(
|
||||
x: torch.Tensor,
|
||||
proj_weight: torch.Tensor,
|
||||
conv1d_weight: torch.Tensor,
|
||||
conv1d_bias: Optional[torch.Tensor] = None,
|
||||
cache: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
# We do matmul and transpose BLH -> HBL at the same time
|
||||
x = rearrange(proj_weight @ rearrange(x, "b l d -> d (b l)"), "d (b l) -> b d l", l=x.shape[-2])
|
||||
|
||||
if causal_conv1d_fn is None:
|
||||
raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
|
||||
if cache is None:
|
||||
x = causal_conv1d_fn(
|
||||
x=x,
|
||||
weight=rearrange(conv1d_weight, "d 1 w -> d w"),
|
||||
bias=conv1d_bias,
|
||||
activation="silu",
|
||||
).transpose(1, 2)
|
||||
else:
|
||||
assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
|
||||
x = x.squeeze(-1)
|
||||
x = causal_conv1d_update(
|
||||
x=x,
|
||||
weight=rearrange(conv1d_weight, "d 1 w -> d w"),
|
||||
bias=conv1d_bias,
|
||||
cache=cache,
|
||||
activation="silu",
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class ShortConvolution(nn.Conv1d):
|
||||
"""
|
||||
Simple wrapper around `nn.Conv1d` that accepts dimension last.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
kernel_size: int,
|
||||
bias: bool = False,
|
||||
activation: Optional[str] = 'silu',
|
||||
use_causal_conv: Optional[bool] = True
|
||||
):
|
||||
super().__init__(in_channels=hidden_size,
|
||||
out_channels=hidden_size,
|
||||
kernel_size=kernel_size,
|
||||
groups=hidden_size,
|
||||
bias=bias,
|
||||
padding=kernel_size - 1)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.activation = None
|
||||
if activation is not None:
|
||||
assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
|
||||
self.activation = activation
|
||||
|
||||
if use_causal_conv:
|
||||
if causal_conv1d_fn is None:
|
||||
warnings.warn("Please install `causal-conv1d` to use causal convolutions, setting `use_causal_conv` to False.")
|
||||
use_causal_conv = False
|
||||
self.use_causal_conv = use_causal_conv
|
||||
|
||||
def extra_repr(self):
|
||||
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
|
||||
', stride={stride}')
|
||||
if self.padding != (0,) * len(self.padding):
|
||||
s += ', padding={padding}'
|
||||
if self.dilation != (1,) * len(self.dilation):
|
||||
s += ', dilation={dilation}'
|
||||
if self.output_padding != (0,) * len(self.output_padding):
|
||||
s += ', output_padding={output_padding}'
|
||||
if self.groups != 1:
|
||||
s += ', groups={groups}'
|
||||
if self.bias is None:
|
||||
s += ', bias=False'
|
||||
if self.padding_mode != 'zeros':
|
||||
s += ', padding_mode={padding_mode}'
|
||||
if self.activation is not None:
|
||||
s += ', activation={activation}'
|
||||
if not self.use_causal_conv:
|
||||
s += ', use_causal_conv={use_causal_conv}'
|
||||
return s.format(**self.__dict__)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
cache: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Tensor of shape `[batch_size, seq_len, hidden_size]`
|
||||
mask (`Optional[torch.Tensor]`):
|
||||
Attention mask dealing with padded positions.
|
||||
cache (`Optional[torch.Tensor]`):
|
||||
Previous cache tensor of shape `[batch_size, hidden_size, kernel_size]`,
|
||||
Returns:
|
||||
Tensor of shape `[batch_size, seq_len, hidden_size]`. The `cache` (if provided) is updated inplace.
|
||||
"""
|
||||
|
||||
if mask is not None:
|
||||
x = x.mul_(mask.unsqueeze(-1))
|
||||
if cache is not None and x.shape[1] == 1:
|
||||
return self.step(x, cache)
|
||||
x = rearrange(x, "b l d -> b d l")
|
||||
# Update state (B D W)
|
||||
if cache is not None:
|
||||
cache.copy_(F.pad(x, (self.kernel_size[0] - x.shape[-1], 0)))
|
||||
if self.use_causal_conv:
|
||||
x = causal_conv1d_fn(
|
||||
x=x,
|
||||
weight=rearrange(self.weight, "d 1 w -> d w"),
|
||||
bias=self.bias,
|
||||
activation=self.activation,
|
||||
)
|
||||
else:
|
||||
x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
|
||||
if self.activation is not None:
|
||||
x = ACT2FN[self.activation](x)
|
||||
return rearrange(x, "b d l -> b l d")
|
||||
|
||||
def step(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cache: torch.Tensor
|
||||
):
|
||||
assert x.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
||||
|
||||
x = x.squeeze(1)
|
||||
if self.use_causal_conv:
|
||||
x = causal_conv1d_update(
|
||||
x=x,
|
||||
conv_state=cache,
|
||||
weight=rearrange(self.weight, "d 1 w -> d w"),
|
||||
bias=self.bias,
|
||||
activation=self.activation,
|
||||
)
|
||||
else:
|
||||
dtype = x.dtype
|
||||
cache.copy_(torch.roll(cache, shifts=-1, dims=-1))
|
||||
cache[:, :, -1] = x
|
||||
x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
|
||||
if self.bias is not None:
|
||||
x = x + self.bias
|
||||
if self.activation is not None:
|
||||
x = ACT2FN[self.activation](x).to(dtype=dtype)
|
||||
return x.unsqueeze(1)
|
||||
|
||||
@property
|
||||
def state_size(self) -> int:
|
||||
return self.hidden_size * self.kernel_size[0]
|
||||
|
||||
|
||||
class LongConvolution(nn.Module):
|
||||
"""
|
||||
LongConvolution applies a convolution operation on the input tensor using a fixed
|
||||
filter of length l_max.
|
||||
The filter is learned during training and is applied using FFT convolution.
|
||||
Args:
|
||||
hidden_size (int): The number of expected features in the input and output.
|
||||
l_max (int): The maximum sequence length.
|
||||
Returns:
|
||||
y: (b, l, d) tensor
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
l_max: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the LongConvolution module.
|
||||
Args:
|
||||
hidden_size (int): The number of expected features in the input and output.
|
||||
l_max (int): The maximum sequence length.
|
||||
"""
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.filter = nn.Parameter(torch.randn(self.hidden_size, l_max), requires_grad=True)
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
"""
|
||||
Applies the LongConvolution operation on the input tensor.
|
||||
Args:
|
||||
x: (b, l, d) tensor
|
||||
Returns:
|
||||
y: (b, l, d) tensor
|
||||
"""
|
||||
x = x.transpose(1, 2)
|
||||
y = fft_conv(x, self.filter, dropout_mask=None, gelu=False)
|
||||
y = y.transpose(1, 2)
|
||||
return y.to(dtype=x.dtype)
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, emb_dim: int, seq_len: int, **kwargs):
|
||||
"""Complex exponential positional embeddings for implicit long convolution filters."""
|
||||
super().__init__()
|
||||
|
||||
self.seq_len = seq_len
|
||||
# The time embedding fed to the filters is normalized so that t_f = 1
|
||||
t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
|
||||
|
||||
if emb_dim > 1:
|
||||
bands = (emb_dim - 1) // 2
|
||||
# To compute the right embeddings we use the "proper" linspace
|
||||
t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
|
||||
w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1
|
||||
|
||||
f = torch.linspace(1e-4, bands - 1, bands)[None, None]
|
||||
z = torch.exp(-1j * f * w)
|
||||
z = torch.cat([t, z.real, z.imag], dim=-1)
|
||||
self.z = nn.Parameter(z, requires_grad=False)
|
||||
|
||||
def forward(self, L):
|
||||
return self.z[:, :L]
|
||||
|
||||
|
||||
class ImplicitLongConvolution(nn.Module):
|
||||
"""
|
||||
Long convolution with implicit filter parameterized by an MLP.
|
||||
|
||||
Args:
|
||||
hidden_size (int):
|
||||
The number of expected features in the input and output.
|
||||
l_max (int):
|
||||
The maximum sequence length.
|
||||
d_emb (Optional[int]):
|
||||
The dimension of the positional embeddings. Must be odd and greater than or equal to 3 (time, sine and cosine).
|
||||
Defaults to 3.
|
||||
d_hidden (Optional[int]):
|
||||
The number of features in the hidden layer of the MLP. Defaults to 16.
|
||||
|
||||
Attributes:
|
||||
pos_emb (`PositionalEmbedding`): The positional embedding layer.
|
||||
mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
l_max: int,
|
||||
d_emb: int = 3,
|
||||
d_hidden: int = 16,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Long convolution with implicit filter parameterized by an MLP.
|
||||
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.d_emb = d_emb
|
||||
|
||||
assert (
|
||||
d_emb % 2 != 0 and d_emb >= 3
|
||||
), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)"
|
||||
self.pos_emb = PositionalEmbedding(d_emb, l_max)
|
||||
|
||||
# final linear layer
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(d_emb, d_hidden),
|
||||
torch.nn.ReLU(),
|
||||
nn.Linear(d_hidden, hidden_size),
|
||||
)
|
||||
|
||||
def filter(self, seq_len: int, *args, **kwargs):
|
||||
k = self.mlp(self.pos_emb(seq_len))
|
||||
|
||||
return k.transpose(1, 2)
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
x: (b, l, d) tensor
|
||||
Returns:
|
||||
y: (b, l, d) tensor
|
||||
"""
|
||||
x = x.transpose(1, 2)
|
||||
k = self.filter(x.shape[-1])
|
||||
y = fft_conv(x, k, dropout_mask=None, gelu=False)
|
||||
|
||||
y = y.transpose(1, 2)
|
||||
return y.to(dtype=x.dtype)
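ShortConvolution keeps the nn.Conv1d weights but accepts dimension-last input and optionally routes through the causal-conv1d CUDA kernels; LongConvolution and ImplicitLongConvolution use FFT convolution instead. A minimal sketch of the two ShortConvolution call paths, assuming causal-conv1d is installed (otherwise the module warns and falls back to the plain PyTorch path):

    import torch
    from fla.modules.convolution import ShortConvolution

    conv = ShortConvolution(hidden_size=64, kernel_size=4).cuda()

    # training / prefill: whole sequence in, whole sequence out
    x = torch.randn(2, 32, 64, device='cuda')        # [batch, seq_len, hidden_size]
    y = conv(x)                                      # [2, 32, 64]

    # decoding: one token at a time with an explicit cache, updated in place
    cache = x.new_zeros(2, 64, 4)                    # [batch, hidden_size, kernel_size]
    y_t = conv(x[:, -1:], cache=cache)               # [2, 1, 64]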
finetune/lora/v6/fla/modules/feature_map.py (vendored, new file)
@@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from fla.modules.layernorm import layer_norm_fn
|
||||
from fla.utils import checkpoint
|
||||
|
||||
|
||||
@checkpoint
|
||||
def flatten_diag_outer_product(x, y):
|
||||
z = torch.einsum("...i,...j->...ij", x, y)
|
||||
N = z.size(-1)
|
||||
indicies = torch.triu_indices(N, N)
|
||||
return z[..., indicies[0], indicies[1]]
|
||||
|
||||
|
||||
@checkpoint
|
||||
def flatten_diag_outer_product_off1(x, y):
|
||||
z = torch.einsum("...i,...j->...ij", x, y)
|
||||
N = z.size(-1)
|
||||
indicies = torch.triu_indices(N, N, 1)
|
||||
indices2 = torch.arange(0, N)
|
||||
return z[..., indicies[0], indicies[1]], z[..., indices2, indices2]
|
||||
|
||||
|
||||
def is_power_of_2(n):
|
||||
return (n & (n - 1) == 0) and n != 0
|
||||
|
||||
|
||||
class HedgehogFeatureMap(nn.Module):
|
||||
|
||||
r"""
|
||||
Hedgehog feature map as introduced in
|
||||
`The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry <https://arxiv.org/abs/2402.04347>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int
|
||||
) -> HedgehogFeatureMap:
|
||||
super().__init__()
|
||||
# Trainable map
|
||||
self.layer = nn.Linear(head_dim, head_dim)
|
||||
self.init_weights_()
|
||||
|
||||
def init_weights_(self):
|
||||
"""Initialize trainable map as identity"""
|
||||
with torch.no_grad():
|
||||
identity = torch.eye(*self.layer.weight.shape[-2:], dtype=torch.float)
|
||||
self.layer.weight.copy_(identity.to(self.layer.weight))
|
||||
nn.init.zeros_(self.layer.bias)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.layer(x) # shape b, h, l, d
|
||||
return torch.cat([2*x, -2*x], dim=-1).softmax(-1)
|
||||
|
||||
|
||||
class T2RFeatureMap(nn.Module):
|
||||
|
||||
r"""
|
||||
Simple linear mapping feature map as in
|
||||
`Finetuning Pretrained Transformers into RNNs <https://arxiv.org/abs/2103.13076>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int,
|
||||
dot_dim: int = None
|
||||
) -> T2RFeatureMap:
|
||||
super().__init__()
|
||||
# Trainable map
|
||||
if dot_dim is None:
|
||||
dot_dim = head_dim
|
||||
self.layer = nn.Linear(head_dim, dot_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.layer(x).relu()
|
||||
|
||||
|
||||
class DPFPFeatureMap(nn.Module):
|
||||
|
||||
r"""
|
||||
Deterministic Parameter-Free Projection (DPFP) feature map in
|
||||
`Linear Transformers Are Secretly Fast Weight Programmers <https://arxiv.org/abs/2102.11174>`_
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int,
|
||||
nu: int = 4
|
||||
) -> DPFPFeatureMap:
|
||||
super().__init__()
|
||||
self.nu = nu
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = torch.cat([x.relu(), -x.relu()], dim=-1)
|
||||
x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu+1)], dim=-1)
|
||||
x_repeat = torch.cat([x] * self.nu, dim=-1)
|
||||
return x_repeat * x_rolled
|
||||
|
||||
|
||||
class HadamardFeatureMap(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int
|
||||
) -> HadamardFeatureMap:
|
||||
super().__init__()
|
||||
# Trainable map
|
||||
self.layer1 = nn.Linear(head_dim, head_dim)
|
||||
self.layer2 = nn.Linear(head_dim, head_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.layer1(x) * self.layer2(x)
|
||||
|
||||
|
||||
class LearnableOuterProductFeatureMap(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int,
|
||||
feature_dim: int
|
||||
) -> LearnableOuterProductFeatureMap:
|
||||
super().__init__()
|
||||
# Trainable map
|
||||
self.layer1 = nn.Linear(head_dim, feature_dim, bias=False)
|
||||
self.layer2 = nn.Linear(head_dim, feature_dim, bias=False)
|
||||
self.normalizer = feature_dim ** -0.5
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return flatten_diag_outer_product(self.layer1(x), self.layer2(x))
|
||||
|
||||
|
||||
class LearnablePolySketchNonNegativeFeatureMap(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int,
|
||||
sketch_size: Optional[int] = None,
|
||||
degree: Optional[int] = 2
|
||||
) -> LearnablePolySketchNonNegativeFeatureMap:
|
||||
super().__init__()
|
||||
|
||||
assert is_power_of_2(degree) and degree >= 2, f"The degree {degree} must be a power of 2"
|
||||
|
||||
self.head_dim = head_dim
|
||||
self.sketch_size = sketch_size if sketch_size is not None else head_dim
|
||||
self.degree = degree
|
||||
|
||||
self.gamma = nn.Parameter(torch.ones(head_dim))
|
||||
self.beta = nn.Parameter(torch.zeros(head_dim))
|
||||
# NOTE: the sketch layers defined here are quite different from the original paper
|
||||
# currently we simply use linear layers without any non-linear activations
|
||||
self.sketches1 = nn.ModuleList([
|
||||
nn.Linear(head_dim, sketch_size, bias=False),
|
||||
*[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
|
||||
])
|
||||
self.sketches2 = nn.ModuleList([
|
||||
nn.Linear(head_dim, sketch_size, bias=False),
|
||||
*[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
|
||||
])
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Section 2.1
|
||||
x = layer_norm_fn(x, self.gamma, self.beta)
|
||||
# first map the input to sketch size with learnable parameters
|
||||
x = self.sketches1[0](x) * self.sketches2[0](x) * self.head_dim ** -0.5
|
||||
for i in range(1, int(math.log2(self.degree)) - 1):
|
||||
x = self.sketches1[i](x) * self.sketches2[i](x) * self.head_dim ** -0.5
|
||||
# do sketch mapping for log2(p) - 1 times in total
|
||||
# do p=2 mapping to ensure non-negativity
|
||||
return flatten_diag_outer_product(x, x)
|
||||
|
||||
|
||||
class TaylorFeatureMap(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int
|
||||
) -> TaylorFeatureMap:
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.r2 = math.sqrt(2)
|
||||
self.rd = math.sqrt(self.head_dim)
|
||||
self.rrd = math.sqrt(self.rd)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
|
||||
return torch.cat([torch.ones_like(x[..., 0:1]), x / self.rrd, x2_2 / (self.rd * self.r2), x2_1 / self.rd], dim=-1)
|
||||
|
||||
|
||||
class RebasedFeatureMap(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int,
|
||||
use_gamma: Optional[bool] = True,
|
||||
use_beta: Optional[bool] = True,
|
||||
normalize: Optional[bool] = True
|
||||
) -> RebasedFeatureMap:
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = head_dim
|
||||
self.use_gamma = use_gamma
|
||||
self.use_beta = use_beta
|
||||
self.normalize = normalize
|
||||
|
||||
self.gamma = None
|
||||
self.beta = None
|
||||
if use_gamma:
|
||||
self.gamma = nn.Parameter(torch.ones(head_dim))
|
||||
if use_beta:
|
||||
self.beta = nn.Parameter(torch.zeros(head_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor, flatten: Optional[bool] = True):
|
||||
if self.use_beta and self.use_gamma and self.normalize:
|
||||
x = layer_norm_fn(x, self.gamma, self.beta)
|
||||
elif self.normalize:
|
||||
x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta)
|
||||
elif self.use_gamma and self.use_beta:
|
||||
x = torch.addcmul(self.beta, x, self.gamma)
|
||||
elif self.use_gamma:
|
||||
x = x.mul(self.gamma)
|
||||
else:
|
||||
raise RuntimeError(f"Not supported combination of `use_gamma`, `use_beta` and `normalize`, "
|
||||
f"which is currentlt set as (`{self.use_gamma}`, `{self.use_beta}`, `{self.normalize}`)")
|
||||
if not flatten:
|
||||
return x
|
||||
x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
|
||||
# rebased use learnable parameters to approximate any quadratic function
|
||||
return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1)
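Most of the feature maps above turn a head_dim-sized query/key into a (usually non-negative) feature vector so that softmax attention can be approximated by a linear kernel. A small sketch of the intended use with HedgehogFeatureMap; the unnormalized, non-causal linear-attention einsums below are only for illustration and are not part of this file:

    import torch
    from fla.modules.feature_map import HedgehogFeatureMap

    phi = HedgehogFeatureMap(head_dim=64).cuda()
    q = torch.randn(2, 8, 128, 64, device='cuda')    # [batch, heads, seq_len, head_dim]
    k = torch.randn(2, 8, 128, 64, device='cuda')
    v = torch.randn(2, 8, 128, 64, device='cuda')

    q, k = phi(q), phi(k)                            # feature dim doubles to 2 * head_dim
    kv = torch.einsum('bhld,bhlm->bhdm', k, v)
    out = torch.einsum('bhld,bhdm->bhlm', q, kv)     # [2, 8, 128, 64]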
finetune/lora/v6/fla/modules/fused_cross_entropy.py (vendored, new file)
@@ -0,0 +1,398 @@
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
|
||||
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
|
||||
# version of PyTorch. The following 2 lines are for backward compatibility with
|
||||
# older PyTorch.
|
||||
if "all_gather_into_tensor" not in dir(torch.distributed):
|
||||
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def cross_entropy_fwd_kernel(
|
||||
loss_ptr, # data ptrs
|
||||
lse_ptr,
|
||||
z_loss_ptr,
|
||||
logits_ptr,
|
||||
labels_ptr,
|
||||
smoothing,
|
||||
logit_scale,
|
||||
lse_square_scale,
|
||||
ignored_index,
|
||||
total_classes,
|
||||
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
|
||||
n_cols, # shapes
|
||||
n_rows,
|
||||
logits_row_stride, # strides
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
HAS_SMOOTHING: tl.constexpr,
|
||||
# if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
|
||||
SPLIT: tl.constexpr,
|
||||
):
|
||||
row_idx = tl.program_id(0)
|
||||
col_block_idx = tl.program_id(1)
|
||||
logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
|
||||
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
label_idx = tl.load(labels_ptr + row_idx)
|
||||
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
|
||||
tl.float32
|
||||
) * logit_scale
|
||||
max_logits = tl.max(logits, 0)
|
||||
if HAS_SMOOTHING:
|
||||
sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
|
||||
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
|
||||
tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
|
||||
if label_idx == ignored_index:
|
||||
loss = 0.0
|
||||
z_loss = 0.0
|
||||
else:
|
||||
label_idx -= class_start_idx
|
||||
if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
|
||||
n_cols, (col_block_idx + 1) * BLOCK_SIZE
|
||||
):
|
||||
logits_label = tl.load(logits_ptr + label_idx) * logit_scale
|
||||
if HAS_SMOOTHING:
|
||||
loss = (
|
||||
(lse if not SPLIT else 0.0)
|
||||
- smoothing * sum_logits / total_classes
|
||||
- (1 - smoothing) * logits_label
|
||||
)
|
||||
else:
|
||||
loss = (lse if not SPLIT else 0.0) - logits_label
|
||||
else:
|
||||
# If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
|
||||
if HAS_SMOOTHING:
|
||||
loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
|
||||
else:
|
||||
loss = 0.0
|
||||
if not SPLIT:
|
||||
z_loss = lse_square_scale * lse * lse
|
||||
loss += z_loss
|
||||
else:
|
||||
z_loss = 0.0
|
||||
tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
|
||||
if not SPLIT:
|
||||
tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def cross_entropy_bwd_kernel(
|
||||
dlogits_ptr, # data ptrs
|
||||
dloss_ptr,
|
||||
logits_ptr,
|
||||
lse_ptr,
|
||||
labels_ptr,
|
||||
smoothing,
|
||||
logit_scale,
|
||||
lse_square_scale,
|
||||
ignored_index,
|
||||
total_classes,
|
||||
class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
|
||||
n_cols, # shapes
|
||||
logits_row_stride, # strides
|
||||
dlogits_row_stride,
|
||||
dloss_row_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
HAS_SMOOTHING: tl.constexpr,
|
||||
):
|
||||
row_idx = tl.program_id(0)
|
||||
col_block_idx = tl.program_id(1)
|
||||
logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
|
||||
dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
|
||||
col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
label_idx = tl.load(labels_ptr + row_idx)
|
||||
if label_idx != ignored_index:
|
||||
dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
|
||||
else:
|
||||
dloss = 0.0
|
||||
logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
|
||||
tl.float32
|
||||
) * logit_scale
|
||||
lse = tl.load(lse_ptr + row_idx)
|
||||
probs = tl.exp(logits - lse)
|
||||
probs += 2.0 * lse_square_scale * lse * probs
|
||||
label_idx -= class_start_idx
|
||||
if HAS_SMOOTHING:
|
||||
smooth_negative = smoothing / total_classes
|
||||
probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
|
||||
else:
|
||||
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
|
||||
tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
|
||||
|
||||
|
||||
class CrossEntropyLossFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
logits,
|
||||
labels,
|
||||
smoothing=0.0,
|
||||
logit_scale=1.0,
|
||||
lse_square_scale=0.0,
|
||||
ignored_index=-100,
|
||||
inplace_backward=False,
|
||||
process_group=None,
|
||||
):
|
||||
n_rows, n_cols = logits.shape
|
||||
assert labels.shape == (n_rows,)
|
||||
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
|
||||
total_classes = world_size * n_cols
|
||||
rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
|
||||
class_start_idx = rank * n_cols
|
||||
|
||||
if logits.stride(-1) != 1:
|
||||
logits = logits.contiguous()
|
||||
# Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
|
||||
MAX_BLOCK_SIZE = 64 * 1024
|
||||
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
|
||||
num_warps = (
|
||||
4
|
||||
if BLOCK_SIZE < 2048
|
||||
else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
|
||||
)
|
||||
# We may split the lse computation across multiple blocks, then do a reduction
|
||||
# lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
|
||||
# where having just one thread block processing more than 64k elements is slow.
|
||||
split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
|
||||
n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
|
||||
loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
|
||||
losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
|
||||
lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
|
||||
z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
|
||||
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
||||
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
||||
with torch.cuda.device(logits.device.index):
|
||||
cross_entropy_fwd_kernel[(n_rows, n_splits)](
|
||||
losses, # data ptrs
|
||||
lse,
|
||||
z_losses,
|
||||
logits,
|
||||
labels,
|
||||
smoothing,
|
||||
logit_scale,
|
||||
lse_square_scale,
|
||||
ignored_index,
|
||||
total_classes,
|
||||
class_start_idx,
|
||||
n_cols, # shapes
|
||||
n_rows,
|
||||
logits.stride(0), # strides
|
||||
BLOCK_SIZE=BLOCK_SIZE, # constants
|
||||
num_warps=num_warps,
|
||||
SPLIT=split,
|
||||
)
|
||||
|
||||
if split:
|
||||
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
|
||||
# - predicted logit, and 0 otherwise.
|
||||
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
|
||||
# -0.9 * predicted logit - 0.1 * sum logit / total_classes.
|
||||
# For labels not in the vocab of this partition, losses contains
|
||||
# -0.1 * sum logit / total_classes.
|
||||
if n_splits > 1:
|
||||
lse = torch.logsumexp(lse, dim=0)
|
||||
losses = losses.sum(dim=0)
|
||||
if world_size > 1:
|
||||
lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
|
||||
torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
|
||||
handle_losses = torch.distributed.all_reduce(
|
||||
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
|
||||
)
|
||||
lse = torch.logsumexp(lse_allgather, dim=0)
|
||||
handle_losses.wait()
|
||||
# After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
|
||||
# we just have to add the (global) lse.
|
||||
# If there's smoothing=0.1, the total losses are
|
||||
# -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
|
||||
# Again, we just have to add the (global) lse.
|
||||
losses += lse
|
||||
if lse_square_scale != 0.0:
|
||||
z_losses = lse_square_scale * lse.square()
|
||||
z_losses.masked_fill_(labels == ignored_index, 0.0)
|
||||
losses += z_losses
|
||||
else:
|
||||
z_losses = torch.zeros_like(losses)
|
||||
losses.masked_fill_(labels == ignored_index, 0.0)
|
||||
|
||||
ctx.save_for_backward(logits, lse, labels)
|
||||
ctx.mark_non_differentiable(z_losses)
|
||||
ctx.smoothing = smoothing
|
||||
ctx.logit_scale = logit_scale
|
||||
ctx.lse_square_scale = lse_square_scale
|
||||
ctx.ignored_index = ignored_index
|
||||
ctx.total_classes = total_classes
|
||||
ctx.class_start_idx = class_start_idx
|
||||
ctx.inplace_backward = inplace_backward
|
||||
|
||||
return losses, z_losses
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_losses, grad_z_losses):
|
||||
del grad_z_losses # z_losses are only for logging.
|
||||
|
||||
logits, lse, labels = ctx.saved_tensors
|
||||
dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
|
||||
n_rows, n_cols = logits.shape
|
||||
BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
|
||||
num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
|
||||
def grid(META): return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
|
||||
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
||||
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
||||
with torch.cuda.device(logits.device.index):
|
||||
cross_entropy_bwd_kernel[grid](
|
||||
dlogits, # data ptrs
|
||||
grad_losses,
|
||||
logits,
|
||||
lse,
|
||||
labels,
|
||||
ctx.smoothing,
|
||||
ctx.logit_scale,
|
||||
ctx.lse_square_scale,
|
||||
ctx.ignored_index,
|
||||
ctx.total_classes,
|
||||
ctx.class_start_idx,
|
||||
n_cols, # shapes
|
||||
logits.stride(0), # strides
|
||||
dlogits.stride(0),
|
||||
grad_losses.stride(0),
|
||||
BLOCK_SIZE=BLOCK_SIZE, # constants
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return dlogits, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def cross_entropy_loss(
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
label_smoothing: float = 0.0,
|
||||
logit_scale: float = 1.0,
|
||||
lse_square_scale: float = 0.0,
|
||||
ignored_index=-100,
|
||||
inplace_backward: bool = False,
|
||||
process_group=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
logits: (batch, vocab_size)
|
||||
labels: (batch,)
|
||||
label_smoothing: float
|
||||
logit_scale: float. Multiply logits by this scale before calculating the loss.
|
||||
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
|
||||
This is also referred to as "z-loss".
|
||||
ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
|
||||
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
|
||||
This saves memory.
|
||||
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
|
||||
one part of the vocab. The loss will be aggregated across processes.
|
||||
Returns:
|
||||
losses: (batch,), float
|
||||
z_losses: (batch,), float
|
||||
"""
|
||||
return CrossEntropyLossFunction.apply(
|
||||
logits,
|
||||
labels,
|
||||
label_smoothing,
|
||||
logit_scale,
|
||||
lse_square_scale,
|
||||
ignored_index,
|
||||
inplace_backward,
|
||||
process_group,
|
||||
)
|
||||
|
||||
|
||||
class FusedCrossEntropyLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ignore_index=-100,
|
||||
reduction="mean",
|
||||
label_smoothing=0.0,
|
||||
logit_scale=1.0,
|
||||
lse_square_scale=0.0,
|
||||
inplace_backward=False,
|
||||
process_group=None,
|
||||
return_z_loss=False,
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
|
||||
label_smoothing: float
|
||||
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
|
||||
This is also referred to as "z-loss".
|
||||
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
|
||||
This saves memory.
|
||||
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
|
||||
one part of the vocab. The loss will be aggregated across processes.
|
||||
return_z_loss: bool. If True, we return the component of the loss contributed by
|
||||
the lse_square_scale value. This value is only for logging and does not support
|
||||
backprop.
|
||||
"""
|
||||
super().__init__()
|
||||
if reduction not in ["mean", "none", "sum"]:
|
||||
raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
|
||||
self.ignore_index = ignore_index
|
||||
self.reduction = reduction
|
||||
self.label_smoothing = label_smoothing
|
||||
self.logit_scale = logit_scale
|
||||
self.lse_square_scale = lse_square_scale
|
||||
self.inplace_backward = inplace_backward
|
||||
self.process_group = process_group
|
||||
self.return_z_loss = return_z_loss
|
||||
|
||||
def forward(self, input, target):
|
||||
"""
|
||||
Arguments:
|
||||
input: (batch, vocab_size)
|
||||
target: (batch,)
|
||||
Returns:
|
||||
losses: (batch,) if reduction is 'none', else (1,), dtype float
|
||||
z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
|
||||
"""
|
||||
assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
|
||||
loss, z_loss = cross_entropy_loss(
|
||||
input,
|
||||
target,
|
||||
label_smoothing=self.label_smoothing,
|
||||
logit_scale=self.logit_scale,
|
||||
lse_square_scale=self.lse_square_scale,
|
||||
ignored_index=self.ignore_index,
|
||||
inplace_backward=self.inplace_backward,
|
||||
process_group=self.process_group,
|
||||
)
|
||||
if self.reduction == "mean":
|
||||
loss = loss.sum() / (target != self.ignore_index).sum()
|
||||
elif self.reduction == "sum":
|
||||
loss = loss.sum()
|
||||
else:
|
||||
loss = loss
|
||||
|
||||
if not self.return_z_loss:
|
||||
return loss
|
||||
|
||||
if self.reduction == "mean":
|
||||
z_loss = z_loss.sum() / (target != self.ignore_index).sum()
|
||||
elif self.reduction == "sum":
|
||||
z_loss = z_loss.sum()
|
||||
else:
|
||||
z_loss = z_loss
|
||||
|
||||
return loss, z_loss
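FusedCrossEntropyLoss is a drop-in replacement for nn.CrossEntropyLoss backed by the Triton forward/backward kernels above, with optional label smoothing, logit scaling, z-loss and tensor-parallel reduction. A minimal usage sketch (CUDA tensors are required; the vocabulary size is illustrative):

    import torch
    from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss

    criterion = FusedCrossEntropyLoss(ignore_index=-100, reduction='mean')
    logits = torch.randn(8, 32000, device='cuda', requires_grad=True)  # [batch, vocab_size]
    labels = torch.randint(0, 32000, (8,), device='cuda')
    loss = criterion(logits, labels)
    loss.backward()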
finetune/lora/v6/fla/modules/fused_norm_gate.py (vendored, new file)
@@ -0,0 +1,889 @@
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
|
||||
# Implement residual + layer_norm / rms_norm.
|
||||
|
||||
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
||||
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
||||
dtype = x.dtype
|
||||
if upcast:
|
||||
weight = weight.float()
|
||||
bias = bias.float() if bias is not None else None
|
||||
if upcast:
|
||||
x = x.float()
|
||||
residual = residual.float() if residual is not None else residual
|
||||
if residual is not None:
|
||||
x = (x + residual).to(x.dtype)
|
||||
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
|
||||
dtype
|
||||
)
|
||||
return out if not prenorm else (out, x)
|
||||
|
||||
|
||||
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
||||
dtype = x.dtype
|
||||
if upcast:
|
||||
weight = weight.float()
|
||||
bias = bias.float() if bias is not None else None
|
||||
if upcast:
|
||||
x = x.float()
|
||||
residual = residual.float() if residual is not None else residual
|
||||
if residual is not None:
|
||||
x = (x + residual).to(x.dtype)
|
||||
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = (x * rstd * weight) + \
|
||||
bias if bias is not None else (x * rstd * weight)
|
||||
out = out.to(dtype)
|
||||
return out if not prenorm else (out, x)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
||||
)
|
||||
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_1pass_kernel(
|
||||
X, # pointer to the input
|
||||
O, # pointer to the gate
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
RESIDUAL, # pointer to the residual
|
||||
RESIDUAL_OUT, # pointer to the residual
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_res_row,
|
||||
stride_res_out_row,
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_RESIDUAL: tl.constexpr,
|
||||
STORE_RESIDUAL_OUT: tl.constexpr,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
X += row * stride_x_row
|
||||
Y += row * stride_y_row
|
||||
O += row * stride_x_row
|
||||
if HAS_RESIDUAL:
|
||||
RESIDUAL += row * stride_res_row
|
||||
if STORE_RESIDUAL_OUT:
|
||||
RESIDUAL_OUT += row * stride_res_out_row
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
if HAS_RESIDUAL:
|
||||
residual = tl.load(RESIDUAL + cols, mask=cols <
|
||||
N, other=0.0).to(tl.float32)
|
||||
x += residual
|
||||
if STORE_RESIDUAL_OUT:
|
||||
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
if HAS_WEIGHT:
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w if HAS_WEIGHT else x_hat
|
||||
if HAS_BIAS:
|
||||
y = y + b
|
||||
|
||||
# Swish output gate
|
||||
o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
y = y * o * tl.sigmoid(o)
|
||||
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(
|
||||
x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
|
||||
):
|
||||
if residual is not None:
|
||||
residual_dtype = residual.dtype
|
||||
M, N = x.shape
|
||||
assert x.stride(-1) == 1
|
||||
if residual is not None:
|
||||
assert residual.stride(-1) == 1
|
||||
assert residual.shape == (M, N)
|
||||
if weight is not None:
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
||||
assert y.stride(-1) == 1
|
||||
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
|
||||
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
|
||||
assert residual_out.stride(-1) == 1
|
||||
else:
|
||||
residual_out = None
|
||||
mean = torch.empty((M,), dtype=torch.float32,
|
||||
device="cuda") if not is_rms_norm else None
|
||||
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
if N > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[(M,)](
|
||||
x,
|
||||
o,
|
||||
y,
|
||||
weight,
|
||||
bias,
|
||||
residual,
|
||||
residual_out,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
y.stride(0),
|
||||
residual.stride(0) if residual is not None else 0,
|
||||
residual_out.stride(0) if residual_out is not None else 0,
|
||||
N,
|
||||
eps,
|
||||
is_rms_norm,
|
||||
BLOCK_N,
|
||||
residual is not None,
|
||||
residual_out is not None,
|
||||
weight is not None,
|
||||
bias is not None,
|
||||
)
|
||||
# residual_out is None if residual is None and residual_dtype == input_dtype
|
||||
return y, mean, rstd, residual_out if residual_out is not None else x
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
|
||||
)
|
||||
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
||||
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
||||
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_bwd_kernel(
|
||||
X, # pointer to the input
|
||||
O, # pointer to the gate
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Y, # pointer to the output to be recomputed
|
||||
DY, # pointer to the output gradient
|
||||
DX, # pointer to the input gradient
|
||||
DO, # pointer to the gate gradient
|
||||
DW, # pointer to the partial sum of weights gradient
|
||||
DB, # pointer to the partial sum of biases gradient
|
||||
DRESIDUAL,
|
||||
DRESIDUAL_IN,
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_dy_row,
|
||||
stride_dx_row,
|
||||
stride_dres_row,
|
||||
stride_dres_in_row,
|
||||
M, # number of rows in X
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
rows_per_program,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_DRESIDUAL: tl.constexpr,
|
||||
STORE_DRESIDUAL: tl.constexpr,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
RECOMPUTE_OUTPUT: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the elements of X, DX, and DY it should compute.
|
||||
row_block_id = tl.program_id(0)
|
||||
row_start = row_block_id * rows_per_program
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
mask = cols < N
|
||||
X += row_start * stride_x_row
|
||||
O += row_start * stride_x_row
|
||||
if HAS_DRESIDUAL:
|
||||
DRESIDUAL += row_start * stride_dres_row
|
||||
if STORE_DRESIDUAL:
|
||||
DRESIDUAL_IN += row_start * stride_dres_in_row
|
||||
DY += row_start * stride_dy_row
|
||||
DX += row_start * stride_dx_row
|
||||
DO += row_start * stride_dx_row
|
||||
if RECOMPUTE_OUTPUT:
|
||||
Y += row_start * stride_y_row
|
||||
if HAS_WEIGHT:
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
row_end = min((row_block_id + 1) * rows_per_program, M)
|
||||
for row in range(row_start, row_end):
|
||||
# Load data to SRAM
|
||||
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
||||
o = tl.load(O + cols, mask=mask, other=0).to(tl.float32)
|
||||
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
||||
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.load(Mean + row)
|
||||
rstd = tl.load(Rstd + row)
|
||||
# Compute dx
|
||||
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
xhat = tl.where(mask, xhat, 0.0)
|
||||
|
||||
y = xhat * w if HAS_WEIGHT else xhat
|
||||
if HAS_BIAS:
|
||||
y = y + b
|
||||
if RECOMPUTE_OUTPUT:
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
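# Swish-gate gradient via the chain rule: with s = sigmoid(o),
# d/do [o * s] = s + o * s * (1 - s), so do = dy * y * (s + o * s * (1 - s)),
# and the gradient passed on to the normalized output is dy * o * s.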
sigmoid_o = tl.sigmoid(o)
|
||||
do = dy * y * (sigmoid_o + o * sigmoid_o * (1 - sigmoid_o))
|
||||
dy = dy * o * sigmoid_o
|
||||
wdy = dy
|
||||
if HAS_WEIGHT:
|
||||
wdy = dy * w
|
||||
dw += dy * xhat
|
||||
if HAS_BIAS:
|
||||
db += dy
|
||||
if not IS_RMS_NORM:
|
||||
c1 = tl.sum(xhat * wdy, axis=0) / N
|
||||
c2 = tl.sum(wdy, axis=0) / N
|
||||
dx = (wdy - (xhat * c1 + c2)) * rstd
|
||||
else:
|
||||
c1 = tl.sum(xhat * wdy, axis=0) / N
|
||||
dx = (wdy - xhat * c1) * rstd
|
||||
if HAS_DRESIDUAL:
|
||||
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
||||
dx += dres
|
||||
# Write dx
|
||||
if STORE_DRESIDUAL:
|
||||
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
||||
tl.store(DX + cols, dx, mask=mask)
|
||||
tl.store(DO + cols, do, mask=mask)
|
||||
|
||||
X += stride_x_row
|
||||
O += stride_x_row
|
||||
if HAS_DRESIDUAL:
|
||||
DRESIDUAL += stride_dres_row
|
||||
if STORE_DRESIDUAL:
|
||||
DRESIDUAL_IN += stride_dres_in_row
|
||||
if RECOMPUTE_OUTPUT:
|
||||
Y += stride_y_row
|
||||
DY += stride_dy_row
|
||||
DX += stride_dx_row
|
||||
DO += stride_dx_row
|
||||
if HAS_WEIGHT:
|
||||
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
||||
if HAS_BIAS:
|
||||
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_bwd(
|
||||
dy,
|
||||
x,
|
||||
o,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
mean,
|
||||
rstd,
|
||||
dresidual=None,
|
||||
has_residual=False,
|
||||
is_rms_norm=False,
|
||||
x_dtype=None,
|
||||
recompute_output=False,
|
||||
):
|
||||
M, N = x.shape
|
||||
assert x.stride(-1) == 1
|
||||
assert dy.stride(-1) == 1
|
||||
assert dy.shape == (M, N)
|
||||
if dresidual is not None:
|
||||
assert dresidual.stride(-1) == 1
|
||||
assert dresidual.shape == (M, N)
|
||||
if weight is not None:
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
dx = (
|
||||
torch.empty_like(x)
|
||||
if x_dtype is None
|
||||
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
||||
)
|
||||
do = (
|
||||
torch.empty_like(o)
|
||||
if x_dtype is None
|
||||
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
||||
)
|
||||
dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
|
||||
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
|
||||
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
if N > BLOCK_N:
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
||||
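# Each program in the backward kernel handles a block of rows and accumulates its own
# float32 partial weight/bias gradients; the (sm_count, N) scratch buffers below are
# summed over dim 0 on the host once the kernel has finished.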
_dw = (
|
||||
torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
||||
if weight is not None
|
||||
else None
|
||||
)
|
||||
_db = (
|
||||
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
||||
if bias is not None
|
||||
else None
|
||||
)
|
||||
rows_per_program = math.ceil(M / sm_count)
|
||||
grid = (sm_count,)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_bwd_kernel[grid](
|
||||
x,
|
||||
o,
|
||||
weight,
|
||||
bias,
|
||||
y,
|
||||
dy,
|
||||
dx,
|
||||
do,
|
||||
_dw,
|
||||
_db,
|
||||
dresidual,
|
||||
dresidual_in,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
0 if not recompute_output else y.stride(0),
|
||||
dy.stride(0),
|
||||
dx.stride(0),
|
||||
dresidual.stride(0) if dresidual is not None else 0,
|
||||
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
||||
M,
|
||||
N,
|
||||
eps,
|
||||
rows_per_program,
|
||||
is_rms_norm,
|
||||
BLOCK_N,
|
||||
dresidual is not None,
|
||||
dresidual_in is not None,
|
||||
weight is not None,
|
||||
bias is not None,
|
||||
)
|
||||
dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
|
||||
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
||||
# Don't need to compute dresidual_in separately in this case
|
||||
if has_residual and dx.dtype == x.dtype:
|
||||
dresidual_in = dx
|
||||
return (dx, do, dw, db, dresidual_in) if not recompute_output else (dx, do, dw, db, dresidual_in, y)
|
||||
|
||||
|
||||
class LayerNormSwishGateFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
o,
|
||||
weight,
|
||||
bias,
|
||||
residual=None,
|
||||
eps=1e-6,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
o_shape_og = o.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
o = o.reshape(-1, o.shape[-1])
|
||||
if residual is not None:
|
||||
assert residual.shape == x_shape_og
|
||||
residual = residual.reshape(-1, residual.shape[-1])
|
||||
residual_dtype = (
|
||||
residual.dtype
|
||||
if residual is not None
|
||||
else (torch.float32 if residual_in_fp32 else None)
|
||||
)
|
||||
y, mean, rstd, residual_out = _layer_norm_fwd(
|
||||
x, o, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
|
||||
)
|
||||
ctx.save_for_backward(residual_out, o, weight, bias, mean, rstd)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
ctx.o_shape_og = o_shape_og
|
||||
ctx.eps = eps
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
ctx.has_residual = residual is not None
|
||||
ctx.prenorm = prenorm
|
||||
ctx.x_dtype = x.dtype
|
||||
y = y.reshape(x_shape_og)
|
||||
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, dy, *args):
|
||||
x, o, weight, bias, mean, rstd = ctx.saved_tensors
|
||||
dy = dy.reshape(-1, dy.shape[-1])
|
||||
assert dy.shape == x.shape
|
||||
if ctx.prenorm:
|
||||
dresidual = args[0]
|
||||
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
||||
assert dresidual.shape == x.shape
|
||||
else:
|
||||
dresidual = None
|
||||
dx, do, dw, db, dresidual_in = _layer_norm_bwd(
|
||||
dy,
|
||||
x,
|
||||
o,
|
||||
weight,
|
||||
bias,
|
||||
ctx.eps,
|
||||
mean,
|
||||
rstd,
|
||||
dresidual,
|
||||
ctx.has_residual,
|
||||
ctx.is_rms_norm,
|
||||
x_dtype=ctx.x_dtype,
|
||||
)
|
||||
return (
|
||||
dx.reshape(ctx.x_shape_og),
|
||||
do.reshape(ctx.o_shape_og),
|
||||
dw,
|
||||
db,
|
||||
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class LayerNormSwishGateLinearFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
o,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual=None,
|
||||
eps=1e-6,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
o_shape_og = o.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
o = o.reshape(-1, o.shape[-1])
|
||||
if residual is not None:
|
||||
assert residual.shape == x_shape_og
|
||||
residual = residual.reshape(-1, residual.shape[-1])
|
||||
residual_dtype = (
|
||||
residual.dtype
|
||||
if residual is not None
|
||||
else (torch.float32 if residual_in_fp32 else None)
|
||||
)
|
||||
y, mean, rstd, residual_out = _layer_norm_fwd(
|
||||
x,
|
||||
o,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
eps,
|
||||
residual,
|
||||
residual_dtype=residual_dtype,
|
||||
is_rms_norm=is_rms_norm
|
||||
)
|
||||
y = y.reshape(x_shape_og)
|
||||
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
||||
linear_weight = linear_weight.to(dtype)
|
||||
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
||||
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
||||
# We don't store y, will be recomputed in the backward pass to save memory
|
||||
ctx.save_for_backward(residual_out, o, norm_weight, norm_bias, linear_weight, mean, rstd)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
ctx.o_shape_og = o_shape_og
|
||||
ctx.eps = eps
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
ctx.has_residual = residual is not None
|
||||
ctx.prenorm = prenorm
|
||||
ctx.x_dtype = x.dtype
|
||||
ctx.linear_bias_is_none = linear_bias is None
|
||||
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, dout, *args):
|
||||
x, o, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
||||
dout = dout.reshape(-1, dout.shape[-1])
|
||||
dy = F.linear(dout, linear_weight.t())
|
||||
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
||||
assert dy.shape == x.shape
|
||||
if ctx.prenorm:
|
||||
dresidual = args[0]
|
||||
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
||||
assert dresidual.shape == x.shape
|
||||
else:
|
||||
dresidual = None
|
||||
dx, do, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
|
||||
dy,
|
||||
x,
|
||||
o,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
ctx.eps,
|
||||
mean,
|
||||
rstd,
|
||||
dresidual=dresidual,
|
||||
has_residual=ctx.has_residual,
|
||||
is_rms_norm=ctx.is_rms_norm,
|
||||
x_dtype=ctx.x_dtype,
|
||||
recompute_output=True,
|
||||
)
|
||||
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
||||
return (
|
||||
dx.reshape(ctx.x_shape_og),
|
||||
do.reshape(ctx.o_shape_og),
|
||||
dnorm_weight,
|
||||
dnorm_bias,
|
||||
dlinear_weight,
|
||||
dlinear_bias,
|
||||
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def layer_norm_swish_gate_fn(
|
||||
x,
|
||||
o,
|
||||
weight,
|
||||
bias,
|
||||
residual=None,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
eps=1e-6
|
||||
):
|
||||
return LayerNormSwishGateFn.apply(
|
||||
x,
|
||||
o,
|
||||
weight,
|
||||
bias,
|
||||
residual,
|
||||
eps,
|
||||
prenorm,
|
||||
residual_in_fp32,
|
||||
False
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_swish_gate_fn(
|
||||
x,
|
||||
o,
|
||||
weight,
|
||||
bias,
|
||||
residual=None,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
eps=1e-6
|
||||
):
|
||||
return LayerNormSwishGateFn.apply(
|
||||
x,
|
||||
o,
|
||||
weight,
|
||||
bias,
|
||||
residual,
|
||||
eps,
|
||||
prenorm,
|
||||
residual_in_fp32,
|
||||
True
|
||||
)
|
||||
|
||||
|
||||
def layer_norm_swish_gate_linear_fn(
|
||||
x,
|
||||
o,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual=None,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
eps=1e-6
|
||||
):
|
||||
return LayerNormSwishGateLinearFn.apply(
|
||||
x,
|
||||
o,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual,
|
||||
eps,
|
||||
prenorm,
|
||||
residual_in_fp32,
|
||||
False
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_swish_gate_linear_fn(
|
||||
x,
|
||||
o,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual=None,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
eps=1e-6
|
||||
):
|
||||
return LayerNormSwishGateLinearFn.apply(
|
||||
x,
|
||||
o,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual,
|
||||
eps,
|
||||
prenorm,
|
||||
residual_in_fp32,
|
||||
True
|
||||
)
|
||||
|
||||
|
||||
class FusedLayerNormSwishGate(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
elementwise_affine: bool = True,
|
||||
eps=1e-5
|
||||
) -> FusedLayerNormSwishGate:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.elementwise_affine = elementwise_affine
|
||||
self.eps = eps
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = f"{self.__class__.__name__}({self.hidden_size}"
|
||||
if not self.elementwise_affine:
|
||||
s += f", elementwise_affine={self.elementwise_affine}"
|
||||
s += f", eps={self.eps}"
|
||||
s += ")"
|
||||
return s
|
||||
|
||||
def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
|
||||
return layer_norm_swish_gate_fn(
|
||||
x,
|
||||
o,
|
||||
self.weight,
|
||||
self.bias,
|
||||
residual=residual,
|
||||
eps=self.eps,
|
||||
prenorm=prenorm,
|
||||
residual_in_fp32=residual_in_fp32
|
||||
)
|
||||
|
||||
|
||||
class FusedRMSNormSwishGate(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
elementwise_affine: bool = True,
|
||||
eps=1e-5
|
||||
) -> FusedRMSNormSwishGate:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.elementwise_affine = elementwise_affine
|
||||
self.eps = eps
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = f"{self.__class__.__name__}({self.hidden_size}"
|
||||
if not self.elementwise_affine:
|
||||
s += f", elementwise_affine={self.elementwise_affine}"
|
||||
s += f", eps={self.eps}"
|
||||
s += ")"
|
||||
return s
|
||||
|
||||
def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False):
|
||||
return rms_norm_swish_gate_fn(
|
||||
x,
|
||||
o,
|
||||
self.weight,
|
||||
self.bias,
|
||||
residual=residual,
|
||||
eps=self.eps,
|
||||
prenorm=prenorm,
|
||||
residual_in_fp32=residual_in_fp32
|
||||
)
|
||||
|
||||
|
||||
class FusedLayerNormSwishGateLinear(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
elementwise_affine: bool = True,
|
||||
eps=1e-5
|
||||
) -> FusedLayerNormSwishGateLinear:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.elementwise_affine = elementwise_affine
|
||||
self.eps = eps
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = f"{self.__class__.__name__}({self.hidden_size}"
|
||||
if not self.elementwise_affine:
|
||||
s += f", elementwise_affine={self.elementwise_affine}"
|
||||
s += f", eps={self.eps}"
|
||||
s += ")"
|
||||
return s
|
||||
|
||||
def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
|
||||
return layer_norm_swish_gate_linear_fn(
|
||||
x,
|
||||
o,
|
||||
self.weight,
|
||||
self.bias,
|
||||
weight,
|
||||
bias,
|
||||
residual=residual,
|
||||
eps=self.eps,
|
||||
prenorm=prenorm,
|
||||
residual_in_fp32=residual_in_fp32
|
||||
)
|
||||
|
||||
|
||||
class FusedRMSNormSwishGateLinear(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
elementwise_affine: bool = True,
|
||||
eps=1e-5
|
||||
) -> FusedRMSNormSwishGateLinear:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.elementwise_affine = elementwise_affine
|
||||
self.eps = eps
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = f"{self.__class__.__name__}({self.hidden_size}"
|
||||
if not self.elementwise_affine:
|
||||
s += f", elementwise_affine={self.elementwise_affine}"
|
||||
s += f", eps={self.eps}"
|
||||
s += ")"
|
||||
return s
|
||||
|
||||
def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
|
||||
return rms_norm_swish_gate_linear_fn(
|
||||
x,
|
||||
o,
|
||||
self.weight,
|
||||
self.bias,
|
||||
weight,
|
||||
bias,
|
||||
residual=residual,
|
||||
eps=self.eps,
|
||||
prenorm=prenorm,
|
||||
residual_in_fp32=residual_in_fp32
|
||||
)
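# Minimal usage sketch for the gated norms above (illustrative only; the hidden size,
# shapes and dtype are assumptions, not values taken from this file):
#
#   norm = FusedRMSNormSwishGate(hidden_size=512).cuda()
#   x = torch.randn(2, 128, 512, device='cuda', dtype=torch.bfloat16)
#   g = torch.randn_like(x)      # gate branch
#   y = norm(x, g)               # ~ RMSNorm(x) * g * sigmoid(g)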
|
||||
216
finetune/lora/v6/fla/modules/l2norm.py
vendored
Normal file
216
finetune/lora/v6/fla/modules/l2norm.py
vendored
Normal file
@@ -0,0 +1,216 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import custom_fwd, custom_bwd
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["N"],
|
||||
)
|
||||
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
||||
@triton.jit
|
||||
def _l2_norm_fwd_1pass_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
X += row * stride_x_row
|
||||
Y += row * stride_x_row
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
xbar = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0)
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
# tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
y = x * rstd
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["N"],
|
||||
)
|
||||
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
||||
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
||||
# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
||||
@triton.jit
|
||||
def _l2_norm_bwd_kernel(
|
||||
X, # pointer to the input
|
||||
# Y, # pointer to the output to be recomputed
|
||||
DY, # pointer to the output gradient
|
||||
DX, # pointer to the input gradient
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the elements of X, DX, and DY it should compute.
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
X += row * stride_x_row
|
||||
DX += row * stride_x_row
|
||||
DY += row * stride_x_row
|
||||
|
||||
# Y += row * stride_y_row
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
x = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(x * x)
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
# tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
# y = x * rstd
|
||||
dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
dy = tl.where(cols < N, dy, 0.0)
|
||||
# dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x
|
||||
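# For y = x * rstd with rstd = 1 / sqrt(sum(x^2) + eps), the vector-Jacobian product is
# dx = dy * rstd - x * sum(dy * x) * rstd^3; the rstd^3 factor appears below as
# (1 / (var + eps)) * rstd.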
dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x
|
||||
tl.store(DX + cols, dx, mask=mask)
|
||||
|
||||
def _l2_norm_fwd(
|
||||
x, eps=1e-6
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
M, N = x.shape
|
||||
assert x.stride(-1) == 1
|
||||
# allocate output
|
||||
y = torch.empty_like(x)
|
||||
assert y.stride(-1) == 1
|
||||
N = x.shape[-1]
|
||||
M = x.shape[0]
|
||||
# rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
if N > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
with torch.cuda.device(x.device.index):
|
||||
_l2_norm_fwd_1pass_kernel[(M,)](
|
||||
x,
|
||||
y,
|
||||
x.stride(0),
|
||||
N,
|
||||
eps,
|
||||
# is_rms_norm,
|
||||
BLOCK_N,
|
||||
# residual is not None,
|
||||
# residual_out is not None,
|
||||
# bias is not None,
|
||||
)
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
def _l2_norm_bwd(
|
||||
x, dy, eps=1e-6,
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
x = x.reshape(-1, dy.shape[-1])
|
||||
dy = dy.reshape(-1, dy.shape[-1])
|
||||
if dy.stride(-1) != 1:
|
||||
dy = dy.contiguous()
|
||||
assert dy.shape == x.shape
|
||||
# allocate output
|
||||
dx = torch.empty_like(x)
|
||||
N = x.shape[-1]
|
||||
M = x.shape[0]
|
||||
assert x.stride(-1) == 1
|
||||
assert dy.stride(-1) == 1
|
||||
# rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
if N > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
with torch.cuda.device(x.device.index):
|
||||
_l2_norm_bwd_kernel[(M,)](
|
||||
x,
|
||||
dy,
|
||||
dx,
|
||||
x.stride(0),
|
||||
N,
|
||||
eps,
|
||||
BLOCK_N,
|
||||
)
|
||||
return dx.reshape(x_shape_og)
|
||||
|
||||
|
||||
class L2NormFN(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
eps=1e-6,
|
||||
):
|
||||
# reshape input data into 2D tensor
|
||||
y = _l2_norm_fwd(x, eps)
|
||||
ctx.x_shape_og = x.shape
|
||||
ctx.eps = eps
|
||||
ctx.x_dtype = x.dtype
|
||||
ctx.save_for_backward(x)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy, *args):
|
||||
x, = ctx.saved_tensors
|
||||
dx = _l2_norm_bwd(
|
||||
x,
|
||||
dy,
|
||||
ctx.eps,
|
||||
)
|
||||
return (
|
||||
dx,
|
||||
None
|
||||
)
|
||||
|
||||
l2_norm_fn = L2NormFN.apply
|
||||
|
||||
if __name__ == '__main__':
|
||||
x = torch.rand(10, 10, 100).cuda().requires_grad_(True)
|
||||
y = torch.nn.functional.normalize(x, dim=-1, p=2)
|
||||
dy = torch.rand_like(y)
|
||||
y.backward(dy, retain_graph=True)
|
||||
x_grad, x.grad = x.grad, None
|
||||
y2 = l2_norm_fn(x, 1e-6)
|
||||
print((y-y2).abs().max())
|
||||
y2.backward(dy, retain_graph=True)
|
||||
x_grad2, x.grad = x.grad, None
|
||||
print((x_grad2-x_grad).abs().max())
|
||||
breakpoint()
802
finetune/lora/v6/fla/modules/layernorm.py
vendored
Normal file
802
finetune/lora/v6/fla/modules/layernorm.py
vendored
Normal file
@@ -0,0 +1,802 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
|
||||
# Implement residual + layer_norm / rms_norm.
|
||||
|
||||
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
||||
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
||||
dtype = x.dtype
|
||||
if upcast:
|
||||
weight = weight.float()
|
||||
bias = bias.float() if bias is not None else None
|
||||
if upcast:
|
||||
x = x.float()
|
||||
residual = residual.float() if residual is not None else residual
|
||||
if residual is not None:
|
||||
x = (x + residual).to(x.dtype)
|
||||
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
|
||||
dtype
|
||||
)
|
||||
return out if not prenorm else (out, x)
|
||||
|
||||
|
||||
def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
|
||||
dtype = x.dtype
|
||||
if upcast:
|
||||
weight = weight.float()
|
||||
bias = bias.float() if bias is not None else None
|
||||
if upcast:
|
||||
x = x.float()
|
||||
residual = residual.float() if residual is not None else residual
|
||||
if residual is not None:
|
||||
x = (x + residual).to(x.dtype)
|
||||
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = (x * rstd * weight + bias) if bias is not None else (x * rstd * weight)
|
||||
out = out.to(dtype)
|
||||
return out if not prenorm else (out, x)
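# Reference semantics: layer_norm_ref computes (x - mean(x)) / sqrt(var(x) + eps) * w + b and
# rms_norm_ref computes x / sqrt(mean(x^2) + eps) * w (+ b), each after optionally adding the
# residual; with prenorm=True both also return the pre-normalization sum.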
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
||||
)
|
||||
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_1pass_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
RESIDUAL, # pointer to the residual
|
||||
RESIDUAL_OUT, # pointer to the residual
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_res_row,
|
||||
stride_res_out_row,
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_RESIDUAL: tl.constexpr,
|
||||
STORE_RESIDUAL_OUT: tl.constexpr,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
X += row * stride_x_row
|
||||
Y += row * stride_y_row
|
||||
if HAS_RESIDUAL:
|
||||
RESIDUAL += row * stride_res_row
|
||||
if STORE_RESIDUAL_OUT:
|
||||
RESIDUAL_OUT += row * stride_res_out_row
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
if HAS_RESIDUAL:
|
||||
residual = tl.load(RESIDUAL + cols, mask=cols <
|
||||
N, other=0.0).to(tl.float32)
|
||||
x += residual
|
||||
if STORE_RESIDUAL_OUT:
|
||||
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
if HAS_WEIGHT:
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
|
||||
y = x_hat * w if HAS_WEIGHT else x_hat
|
||||
if HAS_BIAS:
|
||||
y = y + b
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(
|
||||
x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
|
||||
):
|
||||
if residual is not None:
|
||||
residual_dtype = residual.dtype
|
||||
M, N = x.shape
|
||||
assert x.stride(-1) == 1
|
||||
if residual is not None:
|
||||
assert residual.stride(-1) == 1
|
||||
assert residual.shape == (M, N)
|
||||
if weight is not None:
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
||||
assert y.stride(-1) == 1
|
||||
if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
|
||||
residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
|
||||
assert residual_out.stride(-1) == 1
|
||||
else:
|
||||
residual_out = None
|
||||
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
if N > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[(M,)](
|
||||
x,
|
||||
y,
|
||||
weight,
|
||||
bias,
|
||||
residual,
|
||||
residual_out,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
y.stride(0),
|
||||
residual.stride(0) if residual is not None else 0,
|
||||
residual_out.stride(0) if residual_out is not None else 0,
|
||||
N,
|
||||
eps,
|
||||
is_rms_norm,
|
||||
BLOCK_N,
|
||||
residual is not None,
|
||||
residual_out is not None,
|
||||
weight is not None,
|
||||
bias is not None,
|
||||
)
|
||||
# residual_out is None if residual is None and residual_dtype == input_dtype
|
||||
return y, mean, rstd, residual_out if residual_out is not None else x
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
|
||||
)
|
||||
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
||||
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
||||
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_bwd_kernel(
|
||||
X, # pointer to the input
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Y, # pointer to the output to be recomputed
|
||||
DY, # pointer to the output gradient
|
||||
DX, # pointer to the input gradient
|
||||
DW, # pointer to the partial sum of weights gradient
|
||||
DB, # pointer to the partial sum of biases gradient
|
||||
DRESIDUAL,
|
||||
DRESIDUAL_IN,
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_dy_row,
|
||||
stride_dx_row,
|
||||
stride_dres_row,
|
||||
stride_dres_in_row,
|
||||
M, # number of rows in X
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
rows_per_program,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_DRESIDUAL: tl.constexpr,
|
||||
STORE_DRESIDUAL: tl.constexpr,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
RECOMPUTE_OUTPUT: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the elements of X, DX, and DY it should compute.
|
||||
row_block_id = tl.program_id(0)
|
||||
row_start = row_block_id * rows_per_program
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
mask = cols < N
|
||||
X += row_start * stride_x_row
|
||||
if HAS_DRESIDUAL:
|
||||
DRESIDUAL += row_start * stride_dres_row
|
||||
if STORE_DRESIDUAL:
|
||||
DRESIDUAL_IN += row_start * stride_dres_in_row
|
||||
DY += row_start * stride_dy_row
|
||||
DX += row_start * stride_dx_row
|
||||
if RECOMPUTE_OUTPUT:
|
||||
Y += row_start * stride_y_row
|
||||
if HAS_WEIGHT:
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
row_end = min((row_block_id + 1) * rows_per_program, M)
|
||||
for row in range(row_start, row_end):
|
||||
# Load data to SRAM
|
||||
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
||||
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.load(Mean + row)
|
||||
rstd = tl.load(Rstd + row)
|
||||
# Compute dx
|
||||
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
xhat = tl.where(mask, xhat, 0.0)
|
||||
if RECOMPUTE_OUTPUT:
|
||||
y = xhat * w if HAS_WEIGHT else xhat
|
||||
if HAS_BIAS:
|
||||
y = y + b
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
wdy = dy
|
||||
if HAS_WEIGHT:
|
||||
wdy = dy * w
|
||||
dw += dy * xhat
|
||||
if HAS_BIAS:
|
||||
db += dy
|
||||
if not IS_RMS_NORM:
|
||||
c1 = tl.sum(xhat * wdy, axis=0) / N
|
||||
c2 = tl.sum(wdy, axis=0) / N
|
||||
dx = (wdy - (xhat * c1 + c2)) * rstd
|
||||
else:
|
||||
c1 = tl.sum(xhat * wdy, axis=0) / N
|
||||
dx = (wdy - xhat * c1) * rstd
|
||||
if HAS_DRESIDUAL:
|
||||
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
||||
dx += dres
|
||||
# Write dx
|
||||
if STORE_DRESIDUAL:
|
||||
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
||||
tl.store(DX + cols, dx, mask=mask)
|
||||
|
||||
X += stride_x_row
|
||||
if HAS_DRESIDUAL:
|
||||
DRESIDUAL += stride_dres_row
|
||||
if STORE_DRESIDUAL:
|
||||
DRESIDUAL_IN += stride_dres_in_row
|
||||
if RECOMPUTE_OUTPUT:
|
||||
Y += stride_y_row
|
||||
DY += stride_dy_row
|
||||
DX += stride_dx_row
|
||||
if HAS_WEIGHT:
|
||||
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
||||
if HAS_BIAS:
|
||||
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_bwd(
|
||||
dy,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
mean,
|
||||
rstd,
|
||||
dresidual=None,
|
||||
has_residual=False,
|
||||
is_rms_norm=False,
|
||||
x_dtype=None,
|
||||
recompute_output=False,
|
||||
):
|
||||
M, N = x.shape
|
||||
assert x.stride(-1) == 1
|
||||
assert dy.stride(-1) == 1
|
||||
assert dy.shape == (M, N)
|
||||
if dresidual is not None:
|
||||
assert dresidual.stride(-1) == 1
|
||||
assert dresidual.shape == (M, N)
|
||||
if weight is not None:
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
dx = (
|
||||
torch.empty_like(x)
|
||||
if x_dtype is None
|
||||
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
||||
)
|
||||
dresidual_in = torch.empty_like(
|
||||
x) if has_residual and dx.dtype != x.dtype else None
|
||||
y = torch.empty(M, N, dtype=dy.dtype,
|
||||
device=dy.device) if recompute_output else None
|
||||
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
if N > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
||||
_dw = (
|
||||
torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
||||
if weight is not None
|
||||
else None
|
||||
)
|
||||
_db = (
|
||||
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
||||
if bias is not None
|
||||
else None
|
||||
)
|
||||
rows_per_program = math.ceil(M / sm_count)
|
||||
grid = (sm_count,)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_bwd_kernel[grid](
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
y,
|
||||
dy,
|
||||
dx,
|
||||
_dw,
|
||||
_db,
|
||||
dresidual,
|
||||
dresidual_in,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
0 if not recompute_output else y.stride(0),
|
||||
dy.stride(0),
|
||||
dx.stride(0),
|
||||
dresidual.stride(0) if dresidual is not None else 0,
|
||||
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
||||
M,
|
||||
N,
|
||||
eps,
|
||||
rows_per_program,
|
||||
is_rms_norm,
|
||||
BLOCK_N,
|
||||
dresidual is not None,
|
||||
dresidual_in is not None,
|
||||
weight is not None,
|
||||
bias is not None,
|
||||
)
|
||||
dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
|
||||
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
||||
# Don't need to compute dresidual_in separately in this case
|
||||
if has_residual and dx.dtype == x.dtype:
|
||||
dresidual_in = dx
|
||||
return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
|
||||
|
||||
|
||||
class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
residual=None,
|
||||
eps=1e-6,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
assert residual.shape == x_shape_og
|
||||
residual = residual.reshape(-1, residual.shape[-1])
|
||||
residual_dtype = (
|
||||
residual.dtype
|
||||
if residual is not None
|
||||
else (torch.float32 if residual_in_fp32 else None)
|
||||
)
|
||||
y, mean, rstd, residual_out = _layer_norm_fwd(
|
||||
x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm
|
||||
)
|
||||
ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
ctx.eps = eps
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
ctx.has_residual = residual is not None
|
||||
ctx.prenorm = prenorm
|
||||
ctx.x_dtype = x.dtype
|
||||
y = y.reshape(x_shape_og)
|
||||
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, dy, *args):
|
||||
x, weight, bias, mean, rstd = ctx.saved_tensors
|
||||
dy = dy.reshape(-1, dy.shape[-1])
|
||||
assert dy.shape == x.shape
|
||||
if ctx.prenorm:
|
||||
dresidual = args[0]
|
||||
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
||||
assert dresidual.shape == x.shape
|
||||
else:
|
||||
dresidual = None
|
||||
dx, dw, db, dresidual_in = _layer_norm_bwd(
|
||||
dy,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
ctx.eps,
|
||||
mean,
|
||||
rstd,
|
||||
dresidual,
|
||||
ctx.has_residual,
|
||||
ctx.is_rms_norm,
|
||||
x_dtype=ctx.x_dtype,
|
||||
)
|
||||
return (
|
||||
dx.reshape(ctx.x_shape_og),
|
||||
dw,
|
||||
db,
|
||||
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def layer_norm_fn(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
residual=None,
|
||||
eps=1e-6,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)
|
||||
|
||||
|
||||
def rms_norm_fn(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
residual=None,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
eps=1e-6
|
||||
):
|
||||
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
elementwise_affine: bool = True,
|
||||
eps: float = 1e-5
|
||||
) -> LayerNorm:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.elementwise_affine = elementwise_affine
|
||||
self.eps = eps
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = f"{self.__class__.__name__}({self.hidden_size}"
|
||||
if not self.elementwise_affine:
|
||||
s += f", elementwise_affine={self.elementwise_affine}"
|
||||
s += f", eps={self.eps}"
|
||||
s += ")"
|
||||
return s
|
||||
|
||||
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
||||
return layer_norm_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
residual=residual,
|
||||
eps=self.eps,
|
||||
prenorm=prenorm,
|
||||
residual_in_fp32=residual_in_fp32
|
||||
)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
elementwise_affine: bool = True,
|
||||
eps: float = 1e-5
|
||||
) -> RMSNorm:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.elementwise_affine = elementwise_affine
|
||||
self.eps = eps
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = f"{self.__class__.__name__}({self.hidden_size}"
|
||||
if not self.elementwise_affine:
|
||||
s += f", elementwise_affine={self.elementwise_affine}"
|
||||
s += f", eps={self.eps}"
|
||||
s += ")"
|
||||
return s
|
||||
|
||||
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
||||
return rms_norm_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
residual=residual,
|
||||
eps=self.eps,
|
||||
prenorm=prenorm,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
)
|
||||
|
||||
|
||||
class LayerNormLinearFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual=None,
|
||||
eps=1e-6,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if residual is not None:
|
||||
assert residual.shape == x_shape_og
|
||||
residual = residual.reshape(-1, residual.shape[-1])
|
||||
residual_dtype = (
|
||||
residual.dtype
|
||||
if residual is not None
|
||||
else (torch.float32 if residual_in_fp32 else None)
|
||||
)
|
||||
y, mean, rstd, residual_out = _layer_norm_fwd(
|
||||
x,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
eps,
|
||||
residual,
|
||||
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
|
||||
residual_dtype=residual_dtype,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
y = y.reshape(x_shape_og)
|
||||
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
||||
linear_weight = linear_weight.to(dtype)
|
||||
linear_bias = linear_bias.to(
|
||||
dtype) if linear_bias is not None else None
|
||||
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
||||
# We don't store y, will be recomputed in the backward pass to save memory
|
||||
ctx.save_for_backward(residual_out, norm_weight,
|
||||
norm_bias, linear_weight, mean, rstd)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
ctx.eps = eps
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
ctx.has_residual = residual is not None
|
||||
ctx.prenorm = prenorm
|
||||
ctx.x_dtype = x.dtype
|
||||
ctx.linear_bias_is_none = linear_bias is None
|
||||
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, dout, *args):
|
||||
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
||||
dout = dout.reshape(-1, dout.shape[-1])
|
||||
dy = F.linear(dout, linear_weight.t())
|
||||
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
||||
assert dy.shape == x.shape
|
||||
if ctx.prenorm:
|
||||
dresidual = args[0]
|
||||
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
||||
assert dresidual.shape == x.shape
|
||||
else:
|
||||
dresidual = None
|
||||
dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
|
||||
dy,
|
||||
x,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
ctx.eps,
|
||||
mean,
|
||||
rstd,
|
||||
dresidual,
|
||||
ctx.has_residual,
|
||||
ctx.is_rms_norm,
|
||||
x_dtype=ctx.x_dtype,
|
||||
recompute_output=True,
|
||||
)
|
||||
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
||||
return (
|
||||
dx.reshape(ctx.x_shape_og),
|
||||
dnorm_weight,
|
||||
dnorm_bias,
|
||||
dlinear_weight,
|
||||
dlinear_bias,
|
||||
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def layer_norm_linear_fn(
|
||||
x,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual=None,
|
||||
eps=1e-6,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
return LayerNormLinearFn.apply(
|
||||
x,
|
||||
norm_weight,
|
||||
norm_bias,
|
||||
linear_weight,
|
||||
linear_bias,
|
||||
residual,
|
||||
eps,
|
||||
prenorm,
|
||||
residual_in_fp32,
|
||||
is_rms_norm,
|
||||
)
|
||||
|
||||
|
||||
class LayerNormLinear(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
elementwise_affine: bool = True,
|
||||
eps=1e-5
|
||||
) -> LayerNormLinear:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.elementwise_affine = elementwise_affine
|
||||
self.eps = eps
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = f"{self.__class__.__name__}({self.hidden_size}"
|
||||
if not self.elementwise_affine:
|
||||
s += f", elementwise_affine={self.elementwise_affine}"
|
||||
s += f", eps={self.eps}"
|
||||
s += ")"
|
||||
return s
|
||||
|
||||
def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
|
||||
return layer_norm_linear_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
weight,
|
||||
bias,
|
||||
residual=residual,
|
||||
eps=self.eps,
|
||||
prenorm=prenorm,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
is_rms_norm=False
|
||||
)
|
||||
|
||||
|
||||
class RMSNormLinear(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
elementwise_affine: bool = True,
|
||||
eps=1e-5
|
||||
) -> RMSNormLinear:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.elementwise_affine = elementwise_affine
|
||||
self.eps = eps
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = f"{self.__class__.__name__}({self.hidden_size}"
|
||||
if not self.elementwise_affine:
|
||||
s += f", elementwise_affine={self.elementwise_affine}"
|
||||
s += f", eps={self.eps}"
|
||||
s += ")"
|
||||
return s
|
||||
|
||||
def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
|
||||
return layer_norm_linear_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
weight,
|
||||
bias,
|
||||
residual=residual,
|
||||
eps=self.eps,
|
||||
prenorm=prenorm,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
is_rms_norm=True
|
||||
)
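# Illustrative correctness check against the eager reference (a sketch; shapes, dtype and
# tolerances are assumptions):
#
#   norm = RMSNorm(1024).cuda()
#   x = torch.randn(4, 256, 1024, device='cuda')
#   y_fused = norm(x)
#   y_ref = rms_norm_ref(x, norm.weight, norm.bias, eps=norm.eps, upcast=True)
#   assert torch.allclose(y_fused, y_ref, atol=1e-4, rtol=1e-3)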
|
||||
310
finetune/lora/v6/fla/modules/rotary.py
vendored
Normal file
310
finetune/lora/v6/fla/modules/rotary.py
vendored
Normal file
@@ -0,0 +1,310 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from fla.ops.rotary import apply_rotary
|
||||
|
||||
|
||||
def rotate_half(x, interleaved=False):
|
||||
if not interleaved:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
else:
|
||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
||||
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
|
||||
|
||||
|
||||
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
||||
"""
|
||||
x: (batch_size, seqlen, nheads, headdim)
|
||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
||||
"""
|
||||
ro_dim = cos.shape[-1] * 2
|
||||
assert ro_dim <= x.shape[-1]
|
||||
cos = repeat(
|
||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
||||
sin = repeat(
|
||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
||||
return torch.cat(
|
||||
[x[..., :ro_dim] * cos +
|
||||
rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
class ApplyRotaryEmb(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
cos,
|
||||
sin,
|
||||
interleaved=False,
|
||||
inplace=False,
|
||||
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
):
|
||||
out = apply_rotary(
|
||||
x,
|
||||
cos,
|
||||
sin,
|
||||
seqlen_offsets=seqlen_offsets,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
interleaved=interleaved,
|
||||
inplace=inplace,
|
||||
)
|
||||
if isinstance(seqlen_offsets, int):
|
||||
# Can't save int with save_for_backward
|
||||
ctx.save_for_backward(cos, sin, cu_seqlens)
|
||||
ctx.seqlen_offsets = seqlen_offsets
|
||||
else:
|
||||
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
|
||||
ctx.seqlen_offsets = None
|
||||
ctx.interleaved = interleaved
|
||||
ctx.inplace = inplace
|
||||
ctx.max_seqlen = max_seqlen
|
||||
return out if not inplace else x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
seqlen_offsets = ctx.seqlen_offsets
|
||||
if seqlen_offsets is None:
|
||||
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
||||
else:
|
||||
cos, sin, cu_seqlens = ctx.saved_tensors
|
||||
# TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
|
||||
# "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
|
||||
if not ctx.interleaved and not ctx.inplace:
|
||||
do = do.clone()
|
||||
dx = apply_rotary(
|
||||
do,
|
||||
cos,
|
||||
sin,
|
||||
seqlen_offsets=seqlen_offsets,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=ctx.max_seqlen,
|
||||
interleaved=ctx.interleaved,
|
||||
inplace=ctx.inplace,
|
||||
conjugate=True,
|
||||
)
|
||||
return dx, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
x,
|
||||
cos,
|
||||
sin,
|
||||
interleaved=False,
|
||||
inplace=False,
|
||||
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
||||
else (total_seqlen, nheads, headdim)
|
||||
cos, sin: (seqlen_rotary, rotary_dim / 2)
|
||||
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
||||
of 1st half and 2nd half (GPT-NeoX style).
|
||||
inplace: if True, apply rotary embedding in-place.
|
||||
seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
|
||||
Most commonly used in inference when we have KV cache.
|
||||
cu_seqlens: (batch + 1,) or None
|
||||
max_seqlen: int
|
||||
Return:
|
||||
out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
|
||||
else (total_seqlen, nheads, headdim)
|
||||
rotary_dim must be <= headdim
|
||||
Apply rotary embedding to the first rotary_dim of x.
|
||||
"""
|
||||
return ApplyRotaryEmb.apply(
|
||||
x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
|
||||
)
|
||||
|
||||
|
||||
# For backward compatibility
|
||||
apply_rotary_emb_func = apply_rotary_emb
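# Illustrative call (a sketch; the shapes and rotary dimension below are assumptions):
#
#   q = torch.randn(2, 1024, 8, 64, device='cuda')          # (batch, seqlen, nheads, headdim)
#   inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2, device='cuda').float() / 64))
#   freqs = torch.outer(torch.arange(1024, device='cuda').float(), inv_freq)
#   q_rot = apply_rotary_emb(q, freqs.cos(), freqs.sin())   # cos/sin: (seqlen, rotary_dim / 2)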
|
||||


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the queries and keys are transformed by
    position-dependent rotation matrices, so that their inner product depends only on
    the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; the GPT-NeoX implementation was an inspiration for this one.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """
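
    # In the notation used below: for position t and channel pair j, the angle is
    #     theta[t, j] = t * inv_freq[j],   inv_freq[j] = base ** (-2j / dim),
    # and each pair (x1, x2) is mapped to
    #     (x1 * cos(theta) - x2 * sin(theta), x1 * sin(theta) + x2 * cos(theta)).
    # With XPos, queries are additionally multiplied by scale ** ((t - seqlen/2) / scale_base)
    # and keys by the inverse, so the combined factor in q . k depends only on t_q - t_k.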

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when constructing
            the position indices, we used the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices would also be in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non-trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16,
            # and the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin outputs to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32.
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't use einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
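
    # forward() reads from these caches: queries always use _cos_cached / _sin_cached; with
    # XPos (scale is not None) keys use the inversely scaled _cos_k_cached / _sin_k_cached,
    # so the net scaling in the attention logits depends only on the relative position.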

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in q and k is shifted by this amount.
            Most commonly used in inference when we have a KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding to the first rotary_dim channels of q and k and return them.
        """
        seqlen = q.shape[1]
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=q.device, dtype=q.dtype)
        if self.scale is None:
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )
            k = apply_rotary_emb_func(
                k,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )
        else:
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )
            k = apply_rotary_emb_func(
                k,
                self._cos_k_cached,
                self._sin_k_cached,
                interleaved=self.interleaved,
                seqlen_offsets=seqlen_offset,
            )
        return q, k
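

# Illustrative sketch (not part of the original module): typical use of RotaryEmbedding inside
# an attention layer. Shapes, dtype and the CUDA device are assumptions; the underlying Triton
# kernel requires a GPU. During incremental decoding, seqlen_offset would be the number of
# tokens already in the KV cache.
def _example_rotary_embedding(batch_size=2, seqlen=128, nheads=4, headdim=64):
    rotary = RotaryEmbedding(dim=headdim, interleaved=False).to('cuda')
    q = torch.randn(batch_size, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
    k = torch.randn(batch_size, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
    # Rotates q and k; the cos/sin caches are (re)built lazily for seqlen + seqlen_offset.
    q, k = rotary(q, k, seqlen_offset=0)
    return q, k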