# -*- coding: utf-8 -*-

import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from fla.modules import RMSNorm
from fla.modules.feature_map import (DPFPFeatureMap, HadamardFeatureMap,
                                     HedgehogFeatureMap, T2RFeatureMap)
from fla.ops.linear_attn import (chunk_linear_attn, fused_chunk_linear_attn,
                                 fused_recurrent_linear_attn)


class LinearAttention(nn.Module):
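    r"""
    Multi-head linear attention layer.

    Queries and keys are passed through a configurable feature map
    ('hedgehog', 't2r', 'dpfp', 'elementwise_product', 'elu', 'relu' or
    'identity'), and the attention output is computed by one of the kernels
    in `fla.ops.linear_attn` ('chunk', 'fused_chunk' or 'fused_recurrent').
    """
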
    def __init__(
        self,
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
        num_heads: int = 8,
        mode: str = 'chunk',
        feature_map: str = 'elementwise_product',
        tie_feature_map_qk: bool = False,
        output_norm: str = 'rmsnorm',
        norm_q: bool = False,
        norm_k: bool = False,
        # standard linear attention normalization
        do_feature_map_norm: bool = False,
        elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        **kwargs,
    ):
        super().__init__()
        assert feature_map in ['elu', 'relu', 'hedgehog', 't2r', 'dpfp',
                               'identity', 'elementwise_product'], f"Not supported feature map `{feature_map}`."

        assert output_norm in ['rmsnorm', 'identity'], f"Not supported output norm `{output_norm}`."

        self.hidden_size = hidden_size
        self.mode = mode
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.num_heads = num_heads

        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

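        # Build the feature maps applied to q and k. The learnable maps
        # (Hedgehog, T2R, Hadamard/elementwise product, DPFP) come from
        # fla.modules.feature_map; 'elu' (shifted to stay positive) and 'relu'
        # are simple fixed non-linearities, and 'identity' leaves q/k as-is.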
        if feature_map == 'hedgehog':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
            else:
                self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_qk_dim)
                self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)

        elif feature_map == 't2r':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
            else:
                self.feature_map_q = T2RFeatureMap(head_dim=self.head_qk_dim)
                self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)

        elif feature_map == 'elementwise_product':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
            else:
                self.feature_map_q = HadamardFeatureMap(head_dim=self.head_qk_dim)
                self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)

        elif feature_map == 'dpfp':
            self.feature_map_q = DPFPFeatureMap(head_dim=self.head_qk_dim)
            self.feature_map_k = DPFPFeatureMap(head_dim=self.head_qk_dim)

        elif feature_map == 'elu':
            def elu(x):
                return F.elu(x) + 1
            self.feature_map_q = elu
            self.feature_map_k = elu

        elif feature_map == 'relu':
            self.feature_map_q = nn.ReLU()
            self.feature_map_k = nn.ReLU()

        elif feature_map == 'identity':
            self.feature_map_q = nn.Identity()
            self.feature_map_k = nn.Identity()
        else:
            raise NotImplementedError

        self.do_feature_map_norm = do_feature_map_norm
        if output_norm == 'rmsnorm':
            self.norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
        elif output_norm == 'identity':
            self.norm = nn.Identity()
        else:
            raise NotImplementedError

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        self.norm_q = norm_q
        self.norm_k = norm_k

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(self, x):
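        """
        x: [batch_size, seq_len, hidden_size]
        Returns an output tensor of the same shape.
        """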
        mode = self.mode
        q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        q = self.feature_map_q(q)
        k = self.feature_map_k(k)
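        # Optionally normalize the feature-mapped q/k by their sum over the
        # feature dimension (the feature maps are typically non-negative, so
        # this amounts to an L1 normalization); 1e-4 avoids division by zero.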
        if self.norm_q:
            q = q / (q.sum(-1, keepdim=True) + 1e-4)
        if self.norm_k:
            k = k / (k.sum(-1, keepdim=True) + 1e-4)

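        # The three modes compute the same linear attention and differ only in
        # scheduling: 'chunk' / 'fused_chunk' process the sequence block-wise
        # (typically faster for training), while 'fused_recurrent' steps
        # through the recurrent form token by token.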
        if mode == 'chunk':
            o = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
        elif mode == 'fused_chunk':
            o = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
        elif mode == 'fused_recurrent':
            o = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
        else:
            raise NotImplementedError
        o = self.norm(o)
        o = rearrange(o, 'b h n d -> b n (h d)')
        o = self.o_proj(o)
        return o


if __name__ == '__main__':
    import torch
    # Quick smoke test; the fla kernels run on GPU, so this assumes a CUDA device is available.
    batch = 4
    seq_len = 1024
    hidden_size = 1024
    x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True)
    model = LinearAttention(hidden_size, feature_map='dpfp').to(torch.bfloat16).cuda()
    y = model(x)
    print(y.shape)
    y.sum().backward()
    print(x.grad.shape)