254 lines
9.9 KiB
Python
254 lines
9.9 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
# Sect4.2 of Linear Transformers Are Secretly Fast Weight Programmers https://arxiv.org/abs/2102.11174
|
||
|
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
from typing import Optional, Tuple
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from einops import rearrange
|
||
|
from transformers.cache_utils import Cache
|
||
|
|
||
|
|
||
|
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution, LayerNorm
|
||
|
from fla.modules.rotary import RotaryEmbedding
|
||
|
from fla.ops.delta_rule import (fused_chunk_delta_rule,
|
||
|
fused_recurrent_linear_attn_delta_rule,
|
||
|
chunk_delta_rule)
|
||
|
from torch.nn import functional as F
|
||
|
|
||
|
|
||
|
def simple_norm(x):
|
||
|
return (F.normalize(x, dim=-1) * x.shape[-1] ** 0.5).to(x)
|
||
|
|
||
|
|
||
|
# @torch.jit.script
|
||
|
def elu_p1(x):
|
||
|
return (F.elu(x, 1., False) + 1.).to(x)
|
||
|
|
||
|
|
||
|
# @torch.jit.script
|
||
|
def sum_norm(x):
|
||
|
return (x / x.sum(-1, keepdim=True)).to(x)
|
||
|
|
||
|
|
||
|
# @torch.jit.script
|
||
|
def elu_norm(x):
|
||
|
dtype = x.dtype
|
||
|
x = F.elu(x, 1., False) + 1.
|
||
|
return (x / x.sum(-1, keepdim=True)).to(dtype)
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
# https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1
|
||
|
class DeltaNet(nn.Module):
|
||
|
def __init__(
|
||
|
self,
|
||
|
d_model: int = None,
|
||
|
hidden_size: int = 1024,
|
||
|
expand_k: float = 1.0,
|
||
|
expand_v: float = 1.0,
|
||
|
num_heads: int = 4,
|
||
|
mode: str = 'fused_chunk',
|
||
|
chunk_size: int = 16,
|
||
|
use_beta: bool = True,
|
||
|
use_gate: bool = True,
|
||
|
use_rope: bool = False,
|
||
|
use_output_norm: bool = True,
|
||
|
use_elu: bool = False,
|
||
|
use_short_conv: bool = True,
|
||
|
conv_size: int = 4,
|
||
|
conv_bias: bool = False,
|
||
|
share_conv_kernel: bool = False,
|
||
|
layer_idx: int = None,
|
||
|
qk_activation: str = 'silu',
|
||
|
qk_norm: str = None,
|
||
|
save_memory: str = False,
|
||
|
**kwargs
|
||
|
) -> DeltaNet:
|
||
|
super().__init__()
|
||
|
self.mode = mode
|
||
|
self.qk_activation = qk_activation
|
||
|
self.qk_norm = qk_norm
|
||
|
assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
|
||
|
assert self.qk_norm in ['l2', 'sum']
|
||
|
if d_model is not None:
|
||
|
hidden_size = d_model
|
||
|
self.hidden_size = hidden_size
|
||
|
self.expand_k = expand_k
|
||
|
self.expand_v = expand_v
|
||
|
self.num_heads = num_heads
|
||
|
self.chunk_size = chunk_size
|
||
|
self.use_gate = use_gate
|
||
|
self.use_output_norm = use_output_norm
|
||
|
self.use_short_conv = use_short_conv
|
||
|
self.conv_size = conv_size
|
||
|
self.conv_bias = conv_bias
|
||
|
self.share_conv_kernel = share_conv_kernel
|
||
|
|
||
|
self.key_dim = int(hidden_size * expand_k)
|
||
|
self.value_dim = int(hidden_size * expand_v)
|
||
|
self.head_qk_dim = self.key_dim // num_heads
|
||
|
self.head_v_dim = self.value_dim // num_heads
|
||
|
self.layer_idx = layer_idx
|
||
|
|
||
|
self.silu = torch.nn.SiLU()
|
||
|
|
||
|
assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
|
||
|
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
|
||
|
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
|
||
|
|
||
|
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
|
||
|
self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
|
||
|
self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
|
||
|
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
|
||
|
|
||
|
self.use_beta = use_beta
|
||
|
self.use_elu = use_elu
|
||
|
if self.use_beta:
|
||
|
self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
|
||
|
if use_short_conv:
|
||
|
self.conv_size = conv_size
|
||
|
if share_conv_kernel:
|
||
|
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation=None)
|
||
|
else:
|
||
|
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
|
||
|
self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
|
||
|
self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
|
||
|
if use_gate:
|
||
|
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
|
||
|
if self.use_gate:
|
||
|
self.norm = FusedRMSNormSwishGate(self.head_v_dim)
|
||
|
else:
|
||
|
self.norm = RMSNorm(self.head_v_dim)
|
||
|
self.apply(self._initialize_weights)
|
||
|
|
||
|
def _initialize_weights(self, module: nn.Module):
|
||
|
if getattr(module, "_is_hf_initialized", False):
|
||
|
return
|
||
|
if isinstance(module, nn.Linear):
|
||
|
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
|
||
|
if module.bias is not None:
|
||
|
nn.init.zeros_(module.bias)
|
||
|
module._is_hf_initialized = True
|
||
|
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
hidden_states: torch.Tensor,
|
||
|
attention_mask: Optional[torch.Tensor] = None,
|
||
|
past_key_values: Optional[Cache] = None,
|
||
|
use_cache: Optional[bool] = False,
|
||
|
output_attentions: Optional[bool] = False,
|
||
|
**kwargs
|
||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||
|
|
||
|
# change to inference mode.
|
||
|
mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode
|
||
|
last_state = past_key_values[self.layer_idx] if use_cache else None
|
||
|
|
||
|
if attention_mask is not None:
|
||
|
if attention_mask.shape[-1] != hidden_states.shape[-2]:
|
||
|
attention_mask = attention_mask[:, -1:]
|
||
|
|
||
|
if self.use_short_conv:
|
||
|
conv_state = last_state[0] if use_cache else None
|
||
|
if self.share_conv_kernel:
|
||
|
# conv state is updated inplace
|
||
|
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
|
||
|
q = self.q_proj(hidden_states)
|
||
|
k = self.k_proj(hidden_states)
|
||
|
v = self.v_proj(hidden_states)
|
||
|
else:
|
||
|
conv_state_q = last_state[0] if use_cache else None
|
||
|
conv_state_k = last_state[1] if use_cache else None
|
||
|
conv_state_v = last_state[2] if use_cache else None
|
||
|
k = self.k_proj(hidden_states)
|
||
|
v = self.v_proj(hidden_states)
|
||
|
q = self.q_proj(hidden_states)
|
||
|
q = self.q_conv1d(q, attention_mask, conv_state_q)
|
||
|
k = self.k_conv1d(k, attention_mask, conv_state_k)
|
||
|
v = self.v_conv1d(v, attention_mask, conv_state_v)
|
||
|
else:
|
||
|
q = (self.q_proj(hidden_states))
|
||
|
k = (self.k_proj(hidden_states))
|
||
|
v = self.silu(self.v_proj(hidden_states))
|
||
|
|
||
|
# dealing with left-padding
|
||
|
if attention_mask is not None:
|
||
|
v = v.mul_(attention_mask.unsqueeze(-1))
|
||
|
|
||
|
q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v))
|
||
|
|
||
|
if self.qk_activation != 'silu':
|
||
|
if self.qk_activation == 'relu':
|
||
|
q, k = q.relu(), k.relu()
|
||
|
elif self.qk_activation == 'elu':
|
||
|
q, k = elu_p1(q), elu_p1(k)
|
||
|
elif self.qk_activation == 'identity':
|
||
|
pass
|
||
|
else:
|
||
|
raise NotImplementedError
|
||
|
|
||
|
if self.qk_norm is not None:
|
||
|
if self.qk_norm == 'l2':
|
||
|
k = torch.nn.functional.normalize(k, dim=-1, p=2).to(v) #auto mixed precision type transfer is annoying.
|
||
|
q = torch.nn.functional.normalize(q, dim=-1, p=2).to(v)
|
||
|
elif self.qk_norm == 'sum':
|
||
|
q = sum_norm(q).to(v)
|
||
|
k = sum_norm(k).to(v)
|
||
|
|
||
|
if self.use_beta:
|
||
|
beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid()
|
||
|
else:
|
||
|
beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
|
||
|
state = past_key_values[self.layer_idx][-1] if use_cache else None
|
||
|
if mode == 'fused_recurrent':
|
||
|
o, recurrent_state = fused_recurrent_linear_attn_delta_rule(q, k, v, beta, state, output_final_state=use_cache)
|
||
|
elif mode == 'fused_chunk':
|
||
|
assert self.chunk_size in [16, 32, 64]
|
||
|
o, recurrent_state = fused_chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
|
||
|
elif mode == 'chunk':
|
||
|
assert self.chunk_size in [16, 32, 64]
|
||
|
o, recurrent_state = chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
|
||
|
else:
|
||
|
raise NotImplementedError(f"Not supported mode `{mode}`.")
|
||
|
|
||
|
if past_key_values is not None:
|
||
|
if self.use_short_conv:
|
||
|
if self.share_conv_kernel:
|
||
|
state = (conv_state, recurrent_state)
|
||
|
else:
|
||
|
state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
|
||
|
else:
|
||
|
state = (recurrent_state,)
|
||
|
past_key_values.update(state, self.layer_idx)
|
||
|
|
||
|
o = rearrange(o, 'b h l d -> b l h d')
|
||
|
if self.use_gate:
|
||
|
g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads)
|
||
|
o = self.norm(o, g)
|
||
|
else:
|
||
|
o = self.norm(o)
|
||
|
o = rearrange(o, 'b l h d -> b l (h d)')
|
||
|
o = self.o_proj(o)
|
||
|
|
||
|
return o, None, past_key_values
|
||
|
|
||
|
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
|
||
|
param = next(self.parameters())
|
||
|
state = tuple()
|
||
|
if self.use_short_conv:
|
||
|
if self.share_conv_kernel:
|
||
|
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
|
||
|
else:
|
||
|
# for q/k/v each
|
||
|
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
|
||
|
param.new_zeros(batch_size, self.key_dim, self.conv_size),
|
||
|
param.new_zeros(batch_size, self.value_dim, self.conv_size))
|
||
|
state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
|
||
|
return state
|