144 lines
5.5 KiB
Python
144 lines
5.5 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
from typing import Optional
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
from einops import rearrange
|
||
|
from transformers.activations import ACT2FN
|
||
|
|
||
|
from fla.modules import FusedRMSNormSwishGate, RMSNorm
|
||
|
from fla.ops.simple_gla import chunk_simple_gla
|
||
|
|
||
|
|
||
|
class SimpleGatedLinearAttention(nn.Module):
    r"""
    The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635).  # noqa
    This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.

    Args:
        mode (str, Optional):
            Which GLA kernel to use.
            Currently available: `chunk`.
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 1.0.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 2.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        gate_fn (str, Optional):
            The activation function for the output gate, looked up in `ACT2FN`. Default: `swish`.
        elementwise_affine (bool, Optional):
            Forwarded to the output norm layer (`RMSNorm` or `FusedRMSNormSwishGate`). Default: `True`.
        norm_eps (float, Optional):
            The epsilon value forwarded to the output norm layer. Default: 1e-5.
        gate_logit_normalizer (int, Optional):
            The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
        fuse_norm (bool, Optional):
            Whether to fuse the norm and the output gate for better memory footprint.
            Only takes effect when `gate_fn` is `swish`; otherwise the unfused path is used.
            Default: `True`.
        layer_idx (int, Optional):
            The index of the layer. Stored on the module as `layer_idx`; it is not used by
            the computation in this class. Default: None.

    Raises:
        AssertionError:
            If `mode` is not `chunk`, or if the key/value dims are not divisible by `num_heads`.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 2.0,
        num_heads: int = 4,
        gate_fn: str = 'swish',
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 16,
        fuse_norm: bool = True,
        layer_idx: Optional[int] = None,
        **kwargs
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size

        self.mode = mode
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        assert mode in ['chunk'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
        self.num_heads = num_heads
        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.gate_fn = ACT2FN[gate_fn]
        # Previously `layer_idx` was documented but silently swallowed by **kwargs.
        self.layer_idx = layer_idx

        # Query/key/value and output-gate projections (all bias-free).
        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        # One gate logit per head (head-wise gating); this projection keeps its bias.
        self.gk_proj = nn.Linear(hidden_size, self.num_heads)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        # The fused norm+gate module is only used for the swish gate; any other
        # activation falls back to a separate norm followed by `gate_fn`.
        if gate_fn == 'swish' and fuse_norm:
            self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)

        self.gate_logit_normalizer = gate_logit_normalizer

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module) -> None:
        """Initialize `module` in place if it is an `nn.Linear`, then mark it as initialized.

        Modules already flagged with `_is_hf_initialized` are left untouched, so calling
        this twice on the same module does not re-draw its weights.
        """
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply simple gated linear attention to `x`.

        Args:
            x: Input of shape `(batch, seq_len, hidden_size)`.

        Returns:
            The result of `o_proj`, whose last dimension is `hidden_size`.

        Raises:
            NotImplementedError: If `self.mode` is not `chunk`.
        """
        mode = self.mode
        # Split the projected features into heads: (b, n, h * d) -> (b, h, n, d).
        q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        # Head-wise gate in log space, scaled down by the normalizer.
        gk = rearrange(self.gk_proj(x), 'b n h -> b h n')
        gk = (F.logsigmoid(gk) / self.gate_logit_normalizer)

        if mode == 'chunk':
            o = chunk_simple_gla(q, k, v, gk)
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        o = rearrange(o, 'b h l d -> b l h d')
        g = self.g_proj(x)

        if self.fuse_norm_and_gate:
            # The fused module takes the gate in the same per-head layout as `o`.
            g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
            o = self.g_norm_swish_gate(o, g)
            o = rearrange(o, 'b l h d -> b l (h d)')
        else:
            # Norm is applied per head, the gate after heads are merged again.
            o = self.g_norm(o)
            o = rearrange(o, 'b l h d -> b l (h d)')
            o = o * self.gate_fn(g)
        o = self.o_proj(o)
        return o
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Smoke test: one forward and one backward pass in bfloat16 on a CUDA device.
    n_batch, n_tokens, width = 4, 1024, 2048

    # Sample the input first, then build the layer, so RNG consumption order is unchanged.
    inputs = torch.randn(n_batch, n_tokens, width)
    inputs = inputs.to(torch.bfloat16).cuda()
    # Mark the final CUDA tensor as a leaf requiring grad so `.grad` is populated on it.
    inputs.requires_grad_(True)

    layer = SimpleGatedLinearAttention(hidden_size=width, mode='chunk')
    layer = layer.to(torch.bfloat16).cuda()

    outputs = layer(inputs)
    print(outputs.shape)

    outputs.sum().backward()
    print(inputs.grad.shape)
|