138 lines
5.2 KiB
Python
Vendored
138 lines
5.2 KiB
Python
Vendored
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from einops import rearrange
|
|
|
|
from fla.modules.feature_map import RebasedFeatureMap
|
|
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
|
|
from fla.ops.rebased import parallel_rebased
|
|
|
|
|
|
class ReBasedLinearAttention(nn.Module):
    """Linear-attention layer that feeds queries and keys through a ``RebasedFeatureMap``.

    The input is projected to queries, keys and values with bias-free linear
    layers, queries and keys are transformed by the feature map, and the
    attention output is computed by one of three kernels chosen by ``mode``
    (``"parallel"``, ``"chunk"`` or ``"fused_chunk"``) before being projected
    back to ``hidden_size``.

    Args:
        hidden_size: Size of the last dimension of the input and of the output.
        l_max: Stored on the instance as ``self.l_max``; not otherwise used in this class.
        feature_dim: Per-head width of the query/key projections; also passed
            to ``RebasedFeatureMap``.
        num_key_value_heads: Used to derive ``head_dim`` and the width of the
            value projection.
        num_heads: Number of heads used to split q/k/v in ``forward``.
        use_gamma: Forwarded to ``RebasedFeatureMap``.
        use_beta: Forwarded to ``RebasedFeatureMap``.
        normalize: Forwarded to ``RebasedFeatureMap``.
        causal: Selects the cumulative (causal) or global formulation in
            ``forward_reference``; ``forward`` does not read it.
        eps: Passed to ``parallel_rebased`` and added to the denominator in
            ``forward_reference``.
        mode: Kernel selector, one of ``_SUPPORTED_MODES``.
        layer_idx: Accepted for interface compatibility; unused here.
        **kwargs: Accepted and ignored.
    """

    # Kernels that `forward` knows how to dispatch to.
    _SUPPORTED_MODES = ("fused_chunk", "parallel", "chunk")

    def __init__(
        self,
        hidden_size: int,
        l_max: int = 2048,
        feature_dim: int = 16,
        num_key_value_heads: int = 16,
        num_heads: int = 16,
        use_gamma: Optional[bool] = True,
        use_beta: Optional[bool] = True,
        normalize: Optional[bool] = True,
        causal: bool = True,
        eps: float = 1e-5,
        mode: str = "parallel",
        layer_idx: Optional[int] = None,
        **kwargs
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.l_max = l_max
        self.mode = mode
        assert self.mode in self._SUPPORTED_MODES, (
            f"Not supported mode `{self.mode}`; expected one of {self._SUPPORTED_MODES}."
        )

        # linear attention
        self.feature_dim = feature_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_heads = num_heads
        # NOTE(review): head_dim is derived from num_key_value_heads, while
        # `forward` splits q/k/v with num_heads and o_proj expects
        # num_heads * head_dim inputs. The shapes only line up when both head
        # counts are equal -- confirm before configuring them differently.
        self.head_dim = self.hidden_size // self.num_key_value_heads
        self.use_gamma = use_gamma
        self.use_beta = use_beta
        self.normalize = normalize
        self.causal = causal

        self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        # Placeholder so the output path keeps a dropout hook; currently a no-op.
        self.dropout = nn.Identity()
        self.eps = eps

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module) -> None:
        """Initialise ``nn.Linear`` submodules once and mark the module as initialised.

        Modules already flagged with ``_is_hf_initialized`` are left untouched.
        Linear weights get Xavier-uniform initialisation with gain ``2 ** -2.5``
        and biases (if any) are zeroed. The flag is then set on every visited
        module, linear or not.
        """
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        """Compute the attention output with the kernel selected by ``self.mode``.

        Args:
            hidden_states: Input whose projections are split as
                ``"b l (h d) -> b h l d"``, i.e. a ``(batch, length, hidden_size)`` tensor.
            **kwargs: Accepted and ignored.

        Returns:
            The result of ``o_proj`` (followed by the identity ``dropout``),
            with ``hidden_size`` as the last dimension.

        Raises:
            NotImplementedError: If ``self.mode`` was changed to a value that is
                not in ``_SUPPORTED_MODES``.
        """
        mode = self.mode
        # `mode` is a plain attribute and can be reassigned after construction,
        # so validate it here, before any compute, instead of falling through
        # the dispatch below with `o` unbound.
        if mode not in self._SUPPORTED_MODES:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        q, k, v = (rearrange(x, "b l (h d) -> b h l d", h=self.num_heads) for x in (q, k, v))

        # Only the 'parallel' kernel is fed the un-flattened feature map output.
        flatten = mode != 'parallel'
        q, k = self.feature_map(q, flatten=flatten), self.feature_map(k, flatten=flatten)

        if mode == "fused_chunk":
            o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)
        elif mode == 'chunk':
            o = chunk_linear_attn(q, k, v, normalize=True, scale=1)
        else:  # mode == 'parallel'
            assert q.shape[-1] <= 128, (
                f"'parallel' mode requires a q/k feature size <= 128, got {q.shape[-1]}."
            )
            # The two trailing positional flags are kept exactly as upstream;
            # their meaning is defined by `parallel_rebased` (not visible here).
            o = parallel_rebased(q, k, v, self.eps, True, True)

        o = rearrange(o, "b h l d -> b l (h d)")
        o = self.o_proj(o)
        o = self.dropout(o)
        return o

    # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
    def forward_reference(
        self,
        hidden_states: torch.Tensor,
        filters: Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ):
        """Eager PyTorch formulation of the layer, written without the fused kernels.

        Args:
            hidden_states: Tensor of shape ``(b, l, d)``; the size is unpacked
                as ``b, l, _``.
            filters: Unused; kept for signature compatibility.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor with ``hidden_size`` as the last dimension, cast to the
            dtype of ``hidden_states``.
        """
        # hidden_states = hidden_states.transpose(1, 2)
        b, l, _ = hidden_states.size()
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        # Split heads: (b, l, h * d) -> (b, h, l, d).
        q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2)
        k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
        v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Linear attention
        q, k = self.feature_map(q), self.feature_map(k)
        # Insert singleton axes so `k * v` broadcasts to a per-position outer
        # product with layout (b, h, l, head_dim, feature).
        q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)

        # Compute attention
        if self.causal:
            # Running sums over the sequence axis (dim 2) restrict each position
            # to itself and earlier positions.
            y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
        else:
            # Non-causal: every position sees the sum over the whole sequence.
            y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
        y = rearrange(y, 'b h l d -> b l (h d)')
        y = self.o_proj(y.to(hidden_states.dtype))
        y = self.dropout(y)
        return y.to(hidden_states.dtype)
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke run (requires CUDA): forward/backward once in 'parallel' mode and
    # once in 'fused_chunk' mode on the same input. The numerical comparison
    # between the two runs is currently disabled, so "Pass" only signals that
    # both passes completed without raising.
    dtype = torch.float32
    batch, seq_len, hidden_size = 4, 1024, 1024
    shape = (batch, seq_len, hidden_size)

    # Inputs are drawn before the model is built, so RNG consumption order
    # matches the original script.
    x = torch.randn(*shape).to(dtype).cuda().requires_grad_(True)
    grad_out = torch.randn(*shape).to(dtype).cuda()
    model = ReBasedLinearAttention(hidden_size=hidden_size, mode='parallel')
    model = model.to(dtype).cuda()

    # Pass 1: 'parallel' kernel.
    out_parallel = model(x)
    out_parallel.backward(grad_out, retain_graph=True)
    grad_parallel = x.grad
    x.grad = None  # reset so the second backward starts from a clean gradient
    print(model.mode)

    # Pass 2: same weights and input, 'fused_chunk' kernel.
    model.mode = 'fused_chunk'
    out_fused = model(x)
    print(model.mode)
    out_fused.backward(grad_out)

    # Disabled equivalence checks, kept for reference:
    # assert out_parallel.allclose(out_fused, 0, 1e-4), breakpoint()
    # assert grad_parallel.allclose(x.grad, 0, 1e-4), breakpoint()
    print("Pass")
|