This commit is contained in:
josc146
2024-05-28 22:35:47 +08:00
parent 3488d22d22
commit f05a4acb04
138 changed files with 29047 additions and 334 deletions

25
finetune/lora/v6/fla/layers/__init__.py vendored Normal file
View File

@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
from .abc import ABCAttention
from .based import BasedLinearAttention
from .delta_net import DeltaNet
from .gla import GatedLinearAttention
from .hgrn import HGRNAttention
from .hgrn2 import HGRN2Attention
from .linear_attn import LinearAttention
from .multiscale_retention import MultiScaleRetention
from .rebased import ReBasedLinearAttention
from .rwkv6 import RWKV6Attention
__all__ = [
'ABCAttention',
'BasedLinearAttention',
'DeltaNet',
'GatedLinearAttention',
'HGRNAttention',
'HGRN2Attention',
'LinearAttention',
'MultiScaleRetention',
'ReBasedLinearAttention',
'RWKV6Attention'
]

195
finetune/lora/v6/fla/layers/abc.py vendored Normal file
View File

@@ -0,0 +1,195 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import warnings
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from transformers.cache_utils import Cache
from fla.modules import (FusedRMSNormSwishGate, RMSNorm, RotaryEmbedding,
ShortConvolution)
from fla.modules.activations import swiglu, swish
from fla.modules.convolution import proj_then_conv1d
from fla.ops.abc.chunk import chunk_abc
class ABCAttention(nn.Module):
def __init__(
self,
hidden_size: int = 1024,
expand_k: float = 0.5,
expand_v: float = 1.0,
num_heads: int = 4,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
num_slots: Optional[int] = None,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
gate_low_rank_dim: int = 16,
gate_logit_normalizer: int = 16,
use_input_gate: bool = False,
use_output_gate: bool = True,
use_norm: bool = True,
clamp_min: Optional[float] = -32,
clamp_max: Optional[float] = 32,
layer_idx: Optional[int] = None,
**kwargs
) -> ABCAttention:
super().__init__()
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.key_dim = int(self.hidden_size * self.expand_k)
self.value_dim = int(self.hidden_size * self.expand_v)
self.head_k_dim = self.key_dim // self.num_heads
self.head_v_dim = self.value_dim // self.num_heads
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.gate_low_rank_dim = gate_low_rank_dim
self.gate_logit_normalizer = gate_logit_normalizer
self.use_input_gate = use_input_gate
self.use_output_gate = use_output_gate
self.use_norm = use_norm
if num_slots is None:
num_slots = self.head_k_dim
self.num_slots = num_slots
self.norm_eps = norm_eps
self.clamp_min = clamp_min
self.clamp_max = clamp_max
self.layer_idx = layer_idx
if layer_idx is None:
warnings.warn(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
if use_output_gate:
self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
if self.use_norm:
if self.use_output_gate:
self.g_norm = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
else:
self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
if self.use_rope:
self.rotary = RotaryEmbedding(self.head_k_dim)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
if self.use_short_conv:
if self.share_conv_kernel:
hidden_states = self.h_conv1d(hidden_states)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
else:
q = proj_then_conv1d(hidden_states, self.q_proj.weight, self.q_conv1d.weight, self.q_conv1d.bias)
k = proj_then_conv1d(hidden_states, self.k_proj.weight, self.k_conv1d.weight, self.k_conv1d.bias)
v = proj_then_conv1d(hidden_states, self.v_proj.weight, self.v_conv1d.weight, self.v_conv1d.bias)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
if self.use_input_gate:
q, k, v = map(lambda x: swish(x), (q, k, v))
if self.use_rope:
q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
k = rearrange(k, '... (h d) -> ... h d', h=self.num_heads)
seqlen_offset = 0
if past_key_values is not None:
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
q, k = self.rotary(q, k, seqlen_offset)
q = rearrange(q, 'b n h d -> b h n d', h=self.num_heads)
k = rearrange(k, 'b n h d -> b h n d', h=self.num_heads)
else:
q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)
# [batch_size, n_heads, seq_len, num_slots]
s = rearrange(self.s_proj(hidden_states), 'b t (h m) -> b h t m', h=self.num_heads)
s = s.clamp_(self.clamp_min, self.clamp_max)
last_state = past_key_values[self.layer_idx] if use_cache else None
o, last_state = chunk_abc(q, k, v, s, initial_state=last_state, output_final_state=use_cache)
if past_key_values is not None and last_state is not None:
past_key_values.update(last_state, self.layer_idx, q.shape[2])
o = rearrange(o, 'b h t d -> b t h d')
if self.use_norm and not self.use_output_gate:
o = self.g_norm(o)
elif self.use_output_gate:
g = rearrange(self.g_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_heads)
o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
o = rearrange(o, 'b t h d -> b t (h d)')
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
state += (param.new_zeros(batch_size, self.num_heads, self.head_k_dim, self.num_slots),
param.new_zeros(batch_size, self.num_heads, self.num_slots, self.head_v_dim))
return state
def state_size(self, sequence_length: int = 2048):
return self.num_heads * self.key_dim * self.head_v_dim

126
finetune/lora/v6/fla/layers/based.py vendored Normal file
View File

@@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
"""
Linear attention in Based.
https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
"""
import torch
import torch.nn as nn
from einops import rearrange
from fla.modules.feature_map import TaylorFeatureMap
from fla.ops.based import parallel_based
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
class BasedLinearAttention(nn.Module):
def __init__(
self,
hidden_size: int,
l_max: int = 2048,
feature_dim: int = 16,
num_key_value_heads: int = 12,
num_heads: int = 12,
feature_name: str = "taylor_exp",
eps: float = 1e-12,
causal: bool = True,
mode: str = "parallel",
):
super().__init__()
self.hidden_size
self.l_max = l_max
self.mode = mode
assert self.mode in ["fused_chunk", "parallel", 'chunk']
# linear attention
self.feature_name = feature_name
self.feature_dim = feature_dim
self.num_key_value_heads = num_key_value_heads
self.num_heads = num_heads
self.head_dim = self.hidden_size // self.num_key_value_heads
self.causal = causal
self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.dropout = nn.Identity()
self.feature_map = TaylorFeatureMap(feature_dim)
self.eps = eps
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(self, hidden_states: torch.Tensor, **kwargs):
mode = self.mode
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v])
if mode == "fused_chunk":
q, k = self.feature_map(q), self.feature_map(k)
o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)
elif mode == 'chunk':
q, k = self.feature_map(q), self.feature_map(k)
o = chunk_linear_attn(q, k, v, normalize=True, scale=1)
elif mode == 'parallel':
assert q.shape[-1] <= 128
o = parallel_based(q, k, v, True, True)
o = rearrange(o, "b h l d -> b l (h d)")
o = self.o_proj(o)
o = self.dropout(o)
return o
# https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
"""
x (torch.Tensor): tensor of shape (b, d, l)
y (torch.Tensor): tensor of shape (b, d, l)
"""
# hidden_states = hidden_states.transpose(1, 2)
b, l, _ = hidden_states.size()
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2)
k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Linear attention
q, k = self.feature_map(q), self.feature_map(k)
q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
# Compute attention
if self.causal:
y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
else:
y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
y = rearrange(y, 'b h l d -> b l (h d)')
y = self.o_proj(y.to(hidden_states.dtype))
y = self.dropout(y)
return y.to(hidden_states.dtype)
if __name__ == '__main__':
batch = 4
seq_len = 1024
hidden_size = 1024
dtype = torch.float32
x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True)
dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda()
model = BasedLinearAttention(hidden_size, mode='chunk').to(dtype).cuda()
y = model(x)
y.backward(dy, retain_graph=True)
x_grad, x.grad = x.grad, None
y2 = model.forward_reference(x)
y2.backward(dy)
assert y.allclose(y2, 0, 1e-4), breakpoint()
assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint()
print("Pass")

254
finetune/lora/v6/fla/layers/delta_net.py vendored Normal file
View File

@@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
# Sect4.2 of Linear Transformers Are Secretly Fast Weight Programmers https://arxiv.org/abs/2102.11174
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from transformers.cache_utils import Cache
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution, LayerNorm
from fla.modules.rotary import RotaryEmbedding
from fla.ops.delta_rule import (fused_chunk_delta_rule,
fused_recurrent_linear_attn_delta_rule,
chunk_delta_rule)
from torch.nn import functional as F
def simple_norm(x):
return (F.normalize(x, dim=-1) * x.shape[-1] ** 0.5).to(x)
# @torch.jit.script
def elu_p1(x):
return (F.elu(x, 1., False) + 1.).to(x)
# @torch.jit.script
def sum_norm(x):
return (x / x.sum(-1, keepdim=True)).to(x)
# @torch.jit.script
def elu_norm(x):
dtype = x.dtype
x = F.elu(x, 1., False) + 1.
return (x / x.sum(-1, keepdim=True)).to(dtype)
# https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1
class DeltaNet(nn.Module):
def __init__(
self,
d_model: int = None,
hidden_size: int = 1024,
expand_k: float = 1.0,
expand_v: float = 1.0,
num_heads: int = 4,
mode: str = 'fused_chunk',
chunk_size: int = 16,
use_beta: bool = True,
use_gate: bool = True,
use_rope: bool = False,
use_output_norm: bool = True,
use_elu: bool = False,
use_short_conv: bool = True,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = False,
layer_idx: int = None,
qk_activation: str = 'silu',
qk_norm: str = None,
save_memory: str = False,
**kwargs
) -> DeltaNet:
super().__init__()
self.mode = mode
self.qk_activation = qk_activation
self.qk_norm = qk_norm
assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
assert self.qk_norm in ['l2', 'sum']
if d_model is not None:
hidden_size = d_model
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.chunk_size = chunk_size
self.use_gate = use_gate
self.use_output_norm = use_output_norm
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.layer_idx = layer_idx
self.silu = torch.nn.SiLU()
assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
self.use_beta = use_beta
self.use_elu = use_elu
if self.use_beta:
self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation=None)
else:
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
if use_gate:
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
if self.use_gate:
self.norm = FusedRMSNormSwishGate(self.head_v_dim)
else:
self.norm = RMSNorm(self.head_v_dim)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
# change to inference mode.
mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode
last_state = past_key_values[self.layer_idx] if use_cache else None
if attention_mask is not None:
if attention_mask.shape[-1] != hidden_states.shape[-2]:
attention_mask = attention_mask[:, -1:]
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
else:
conv_state_q = last_state[0] if use_cache else None
conv_state_k = last_state[1] if use_cache else None
conv_state_v = last_state[2] if use_cache else None
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = self.q_proj(hidden_states)
q = self.q_conv1d(q, attention_mask, conv_state_q)
k = self.k_conv1d(k, attention_mask, conv_state_k)
v = self.v_conv1d(v, attention_mask, conv_state_v)
else:
q = (self.q_proj(hidden_states))
k = (self.k_proj(hidden_states))
v = self.silu(self.v_proj(hidden_states))
# dealing with left-padding
if attention_mask is not None:
v = v.mul_(attention_mask.unsqueeze(-1))
q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v))
if self.qk_activation != 'silu':
if self.qk_activation == 'relu':
q, k = q.relu(), k.relu()
elif self.qk_activation == 'elu':
q, k = elu_p1(q), elu_p1(k)
elif self.qk_activation == 'identity':
pass
else:
raise NotImplementedError
if self.qk_norm is not None:
if self.qk_norm == 'l2':
k = torch.nn.functional.normalize(k, dim=-1, p=2).to(v) #auto mixed precision type transfer is annoying.
q = torch.nn.functional.normalize(q, dim=-1, p=2).to(v)
elif self.qk_norm == 'sum':
q = sum_norm(q).to(v)
k = sum_norm(k).to(v)
if self.use_beta:
beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid()
else:
beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
state = past_key_values[self.layer_idx][-1] if use_cache else None
if mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_linear_attn_delta_rule(q, k, v, beta, state, output_final_state=use_cache)
elif mode == 'fused_chunk':
assert self.chunk_size in [16, 32, 64]
o, recurrent_state = fused_chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
elif mode == 'chunk':
assert self.chunk_size in [16, 32, 64]
o, recurrent_state = chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
state = (conv_state, recurrent_state)
else:
state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
else:
state = (recurrent_state,)
past_key_values.update(state, self.layer_idx)
o = rearrange(o, 'b h l d -> b l h d')
if self.use_gate:
g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads)
o = self.norm(o, g)
else:
o = self.norm(o)
o = rearrange(o, 'b l h d -> b l (h d)')
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
# for q/k/v each
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.value_dim, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
return state

234
finetune/lora/v6/fla/layers/gated_abc.py vendored Normal file
View File

@@ -0,0 +1,234 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import warnings
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers.cache_utils import Cache
from fla.modules import (FusedRMSNormSwishGateLinear, RMSNormLinear,
RotaryEmbedding, ShortConvolution)
from fla.modules.activations import ACT2FN, swiglu_linear, swish
from fla.ops.abc.chunk_gate import chunk_gated_abc
class GatedABCAttention(nn.Module):
def __init__(
self,
hidden_size: int = 1024,
expand_k: float = 1.,
expand_v: float = 1.,
num_heads: int = 4,
num_kv_heads: Optional[int] = None,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
num_slots: Optional[int] = None,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
gate_low_rank_dim: Optional[int] = None,
gate_logit_normalizer: int = 16,
feature_map: str = 'swish',
use_rope: bool = False,
use_output_gate: bool = False,
use_norm: bool = True,
layer_idx: Optional[int] = None,
**kwargs
) -> GatedABCAttention:
super().__init__()
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.key_dim_per_group = self.key_dim // self.num_kv_groups
self.value_dim_per_group = self.value_dim // self.num_kv_groups
self.head_k_dim = self.key_dim // self.num_heads
self.head_v_dim = self.value_dim // self.num_heads
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
if gate_low_rank_dim is None:
gate_low_rank_dim = self.hidden_size // 16
self.gate_low_rank_dim = gate_low_rank_dim
self.gate_logit_normalizer = gate_logit_normalizer
self.feature_map = feature_map
self.use_rope = use_rope
self.use_output_gate = use_output_gate
self.use_norm = use_norm
if num_slots is None:
num_slots = self.head_k_dim
self.num_slots = num_slots
self.layer_idx = layer_idx
if layer_idx is None:
warnings.warn(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)
if use_output_gate:
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
if self.use_norm:
if self.use_output_gate:
self.g_norm = FusedRMSNormSwishGateLinear(self.hidden_size, elementwise_affine, norm_eps)
else:
self.g_norm = RMSNormLinear(self.hidden_size, elementwise_affine, norm_eps)
self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
if self.use_rope:
self.rotary = RotaryEmbedding(self.head_k_dim)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
lower_bound: Optional[torch.Tensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
last_state = past_key_values[self.layer_idx] if use_cache else None
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
else:
conv_state_q = last_state[0] if use_cache else None
conv_state_k = last_state[1] if use_cache else None
conv_state_v = last_state[2] if use_cache else None
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = self.q_conv1d(q, attention_mask, conv_state_q)
k = self.k_conv1d(k, attention_mask, conv_state_k)
v = self.v_conv1d(v, attention_mask, conv_state_v)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
f = self.f_proj(hidden_states)
if self.use_rope:
q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
seqlen_offset = 0
if past_key_values is not None:
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
q, k = self.rotary(q, k, seqlen_offset)
q = rearrange(q, 'b n h d -> b h n d', h=self.num_heads)
k = rearrange(k, 'b n h d -> b h n d', h=self.num_kv_heads)
else:
q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
if self.num_kv_groups > 1:
k = repeat(k, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups)
else:
k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_kv_heads)
if self.num_kv_groups > 1:
v = repeat(v, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups)
f = repeat(f, 'b n (h m) -> b (h g) n m', h=self.num_kv_heads, g=self.num_kv_groups)
else:
v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_kv_heads)
f = rearrange(f, 'b n (h m) -> b h n m', h=self.num_kv_heads)
if self.feature_map is not None:
q, k, v = map(lambda x: ACT2FN[self.feature_map](x), (q, k, v))
f = F.logsigmoid(f) / self.gate_logit_normalizer
s = (1 - f.exp()).to(f.dtype)
# dealing with left-padding
if attention_mask is not None:
s = s.mul_(attention_mask.view(attention_mask.shape[0], 1, -1, 1))
v = v.mul_(attention_mask.view(attention_mask.shape[0], 1, -1, 1))
recurrent_state = last_state[-2:] if use_cache else None
o, recurrent_state = chunk_gated_abc(q, k, v, s, f,
initial_state=recurrent_state,
output_final_state=use_cache)
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
last_state = (conv_state,) + recurrent_state
else:
last_state = (conv_state_q, conv_state_k, conv_state_v) + recurrent_state
else:
last_state = recurrent_state
past_key_values.update(last_state, self.layer_idx, q.shape[2])
o = rearrange(o, 'b h t d -> b t (h d)')
if self.use_norm and not self.use_output_gate:
o = swish(o)
o = self.g_norm(o, self.o_proj.weight, self.o_proj.bias)
elif self.use_output_gate and not self.use_norm:
o = swiglu_linear(self.g_proj(hidden_states), o, self.o_proj.weight, self.o_proj.bias)
elif self.use_output_gate and self.use_norm:
o = self.g_norm(o, self.g_proj(hidden_states), self.o_proj.weight, self.o_proj.bias)
else:
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.value_dim, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_k_dim, self.num_slots),
param.new_zeros(batch_size, self.num_heads, self.num_slots, self.head_v_dim))
return state
def state_size(self, sequence_length: int = 2048):
return self.num_heads * self.key_dim * self.head_v_dim

268
finetune/lora/v6/fla/layers/gla.py vendored Normal file
View File

@@ -0,0 +1,268 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers.cache_utils import Cache
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
from fla.modules.activations import ACT2FN
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
class GatedLinearAttention(nn.Module):
r"""
The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
Args:
mode (str, Optional):
Which GLA kernel to use.
Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
Default: `chunk`.
hidden_size (int, Optional):
The hidden size of the input. Default: 1024.
expand_k (float, Optional):
The expansion ratio for the key dim. Default: 0.5.
expand_v (float, Optional):
The expansion ratio for the value dim. Default: 1.0.
num_heads (int, Optional):
The number of heads. Default: 4.
num_kv_heads (int, Optional):
The number of key/value heads, used for MQA. Default: None.
feature_map (str, Optional):
Feature map function applied to queries/keys. Default: None.
use_short_conv (bool, Optional):
Whether to use short convolutions. Default: `False`.
conv_size (int, Optional):
The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
conv_bias (bool, Optional):
Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
share_conv_kernel (bool, Optional):
Whether to apply convolutions berfore q/k/v mapping, only taking effects when `use_short_conv`. Default: `True`.
use_output_gate (bool, Optional):
Whether to use output gate. Default: `True`.
gate_fn (str, Optional):
The activation function for the output gate. Default: `swish`.
elementwise_affine (bool, Optional):
If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
norm_eps (float, Optional):
The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
gate_logit_normalizer (int, Optional):
The normalizer for the gate logits, appied after `logsigmoid`. Default: 16.
gate_low_rank_dim (int, Optional):
The low rank dim for the gate projection. Default: 16.
clamp_min (float, Optional):
The minimum value for the gate logits. Default: None.
fuse_norm (bool, Optional):
Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
layer_idx (int, Optional):
The index of the layer. Default: None.
"""
def __init__(
self,
mode: str = 'chunk',
hidden_size: int = 1024,
expand_k: float = 0.5,
expand_v: float = 1.0,
num_heads: int = 4,
num_kv_heads: Optional[int] = None,
feature_map: Optional[str] = None,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
use_output_gate: bool = True,
gate_fn: str = 'swish',
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
gate_logit_normalizer: int = 16,
gate_low_rank_dim: int = 16,
clamp_min: Optional[float] = None,
fuse_norm: bool = True,
layer_idx: int = None,
) -> GatedLinearAttention:
super().__init__()
self.mode = mode
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.use_output_gate = use_output_gate
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.key_dim_per_group = self.key_dim // self.num_kv_groups
self.value_dim_per_group = self.value_dim // self.num_kv_groups
self.clamp_min = clamp_min
self.layer_idx = layer_idx
assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
if self.use_output_gate:
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
if gate_fn == 'swish' and fuse_norm and use_output_gate:
self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
self.fuse_norm_and_gate = True
else:
self.fuse_norm_and_gate = False
self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
self.gate_fn = ACT2FN[gate_fn]
self.gate_logit_normalizer = gate_logit_normalizer
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
# launching the triton kernel for just one token will actually be slower
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
last_state = past_key_values[self.layer_idx] if use_cache else None
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
else:
conv_state_q = last_state[0] if use_cache else None
conv_state_k = last_state[1] if use_cache else None
conv_state_v = last_state[2] if use_cache else None
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = self.q_conv1d(q, attention_mask, conv_state_q)
k = self.k_conv1d(k, attention_mask, conv_state_k)
v = self.v_conv1d(v, attention_mask, conv_state_v)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
gk = self.gk_proj(hidden_states)
if self.feature_map_fn is not None:
q, k = map(self.feature_map_fn, (q, k))
# dealing with left-padding
if attention_mask is not None:
v = v.mul_(attention_mask.unsqueeze(-1))
q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads)
if self.num_kv_groups > 1:
k, v, gk = (repeat(x, 'b l (h d) -> b (h g) l d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v, gk))
else:
k, v, gk = (rearrange(x, 'b l (h d) -> b h l d', h=self.num_kv_heads) for x in (k, v, gk))
gk = F.logsigmoid(gk) / self.gate_logit_normalizer
if self.clamp_min is not None:
gk = torch.clamp_min(gk, self.clamp_min)
recurrent_state = last_state[-1] if use_cache else None
if mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
elif mode == 'fused_chunk':
o, recurrent_state = fused_chunk_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
elif mode == 'chunk':
o, recurrent_state = chunk_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
last_state = (conv_state, recurrent_state)
else:
last_state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
else:
last_state = (recurrent_state,)
past_key_values.update(last_state, self.layer_idx, q.shape[2])
o = rearrange(o, 'b h l d -> b l h d')
if self.use_output_gate:
g = self.g_proj(hidden_states)
if self.fuse_norm_and_gate:
g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
o = self.g_norm_swish_gate(o, g)
o = rearrange(o, 'b l h d -> b l (h d)')
else:
o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
o = o * self.gate_fn(g)
else:
o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.value_dim, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
return state
def state_size(self, **kwargs) -> int:
state_size = self.key_dim * self.head_v_dim
for module in self.children():
if isinstance(module, ShortConvolution):
state_size += module.state_size
return state_size

165
finetune/lora/v6/fla/layers/hgrn.py vendored Normal file
View File

@@ -0,0 +1,165 @@
# -*- coding: utf-8 -*-
# "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.cache_utils import Cache
from fla.modules import FusedRMSNormSwishGate, ShortConvolution
from fla.modules.activations import swiglu
from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
class HGRNAttention(nn.Module):
def __init__(
self,
mode: str = 'chunk',
hidden_size: int = 1024,
num_heads: Optional[int] = None,
expand_ratio: Optional[int] = 1,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
layer_idx: int = None
) -> HGRNAttention:
super().__init__()
self.mode = mode
self.hidden_size = hidden_size
self.num_heads = num_heads
self.expand_ratio = expand_ratio
self.input_dim = int(hidden_size * expand_ratio)
self.head_dim = self.input_dim // self.num_heads
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.layer_idx = layer_idx
assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
assert self.hidden_size % num_heads == 0, f"hidden size must be divisible by num_heads of {num_heads}"
self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
self.g_norm = FusedRMSNormSwishGate(self.input_dim, elementwise_affine, norm_eps)
self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
lower_bound: Optional[torch.Tensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
# launching the triton kernel for just one token will actually be slower
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
last_state = past_key_values[self.layer_idx] if use_cache else None
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
i = self.i_proj(hidden_states)
f = self.f_proj(hidden_states)
else:
conv_state_i = last_state[2] if use_cache else None
conv_state_f = last_state[1] if use_cache else None
i = self.i_conv1d(self.i_proj(hidden_states), attention_mask, conv_state_i)
f = self.f_conv1d(self.f_proj(hidden_states), attention_mask, conv_state_f)
else:
i = self.i_proj(hidden_states)
f = self.f_proj(hidden_states)
# the lower bound for the first layer is zero
if lower_bound is None or self.layer_idx == 0:
i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
else:
g = lower_bound + (1 - lower_bound) * f.sigmoid()
i, f = swiglu(i, 1 - g), g.log()
# dealing with left-padding
if attention_mask is not None:
i = i.mul_(attention_mask.unsqueeze(-1))
i, f = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (i, f))
recurrent_state = last_state[-1] if use_cache else None
if mode == 'chunk':
o, recurrent_state = chunk_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache)
elif mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
last_state = (conv_state, recurrent_state)
else:
last_state = (conv_state_i, conv_state_f, recurrent_state)
else:
last_state = (recurrent_state,)
past_key_values.update(last_state, self.layer_idx, i.shape[2])
o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)'))
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),
param.new_zeros(batch_size, self.hidden_size, self.conv_size),
param.new_zeros(batch_size, self.hidden_size, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_dim),)
return state
def state_size(self, **kwargs) -> int:
state_size = self.hidden_size
for module in self.children():
if isinstance(module, ShortConvolution):
state_size += module.state_size
return state_size

186
finetune/lora/v6/fla/layers/hgrn2.py vendored Normal file
View File

@@ -0,0 +1,186 @@
# -*- coding: utf-8 -*-
# "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904]
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.cache_utils import Cache
from fla.modules import RMSNorm, ShortConvolution
from fla.modules.activations import swish
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
class HGRN2Attention(nn.Module):
def __init__(
self,
mode: str = 'chunk',
hidden_size: int = 1024,
num_heads: Optional[int] = None,
expand_ratio: Optional[int] = 128,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
layer_idx: int = None
) -> HGRN2Attention:
super().__init__()
self.mode = mode
self.hidden_size = hidden_size
if expand_ratio is None and num_heads is not None:
expand_ratio = hidden_size // num_heads
elif expand_ratio is not None and num_heads is None:
num_heads = hidden_size // expand_ratio
else:
raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
self.num_heads = num_heads
self.expand_ratio = expand_ratio
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.forget_dim = int(self.num_heads * self.expand_ratio)
self.input_dim = hidden_size
self.layer_idx = layer_idx
assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`."
assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}"
assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}"
self.head_f_dim = self.expand_ratio
self.head_i_dim = self.hidden_size // num_heads
self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu')
self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu')
self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, norm_eps)
self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
lower_bound: Optional[torch.Tensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
# launching the triton kernel for just one token will actually be slower
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
last_state = past_key_values[self.layer_idx] if use_cache else None
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
q = self.q_proj(hidden_states)
f = self.f_proj(hidden_states)
i = self.i_proj(hidden_states)
else:
conv_state_q = last_state[0] if use_cache else None
conv_state_f = last_state[1] if use_cache else None
conv_state_i = last_state[2] if use_cache else None
q = self.q_proj(hidden_states)
f = self.f_proj(hidden_states)
i = self.i_proj(hidden_states)
q = self.q_conv1d(q, attention_mask, conv_state_q)
f = self.f_conv1d(f, attention_mask, conv_state_f)
i = self.i_conv1d(i, attention_mask, conv_state_i)
else:
q = self.q_proj(hidden_states)
f = self.f_proj(hidden_states)
i = self.i_proj(hidden_states)
# dealing with left-padding
if attention_mask is not None:
i = i.mul_(attention_mask.unsqueeze(-1))
q = swish(q)
# the lower bound for the first layer is zero
if lower_bound is None or self.layer_idx == 0:
k, g = 1 - f.sigmoid(), F.logsigmoid(f)
else:
g = lower_bound + (1 - lower_bound) * f.sigmoid()
k, g = 1 - g, g.log()
q, k, i, g = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, i, g))
recurrent_state = last_state[-1] if use_cache else None
if mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
elif mode == 'fused_chunk':
o, recurrent_state = fused_chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
elif mode == 'chunk':
o, recurrent_state = chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
last_state = (conv_state, recurrent_state)
else:
last_state = (conv_state_q, conv_state_f, conv_state_i, recurrent_state)
else:
last_state = (recurrent_state,)
past_key_values.update(last_state, self.layer_idx, q.shape[2])
o = self.g_norm(rearrange(o, 'b h l d -> b l (h d)'))
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
state += (param.new_zeros(batch_size, self.forget_dim, self.conv_size),
param.new_zeros(batch_size, self.forget_dim, self.conv_size),
param.new_zeros(batch_size, self.input_dim, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_f_dim, self.head_i_dim),)
return state
def state_size(self, **kwargs) -> int:
state_size = self.forget_dim * self.head_i_dim
for module in self.children():
if isinstance(module, ShortConvolution):
state_size += module.state_size
return state_size

View File

@@ -0,0 +1,156 @@
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from fla.modules import RMSNorm
from fla.modules.feature_map import (DPFPFeatureMap, HadamardFeatureMap,
HedgehogFeatureMap, T2RFeatureMap)
from fla.ops.linear_attn import (chunk_linear_attn, fused_chunk_linear_attn,
fused_recurrent_linear_attn)
class LinearAttention(nn.Module):
def __init__(
self,
hidden_size: str = 1024,
expand_k: int = 1.0,
expand_v: int = 1.0,
num_heads: int = 8,
mode: str = 'chunk',
feature_map: str = 'elementwise_product',
tie_feature_map_qk: bool = False,
output_norm: str = 'rmsnorm',
norm_q: bool = False,
norm_k: bool = False,
# standard linear attention normalization
do_feature_map_norm: bool = False,
elementwise_affine: bool = True,
norm_eps: float = 1e-5,
**kwargs,
):
super().__init__()
assert feature_map in ['elu', 'relu', 'hedgehog', 't2r', 'dpfp',
'identity', 'elementwise_product'], f"Not supported feature map `{feature_map}`."
assert output_norm in ['rmsnorm', 'identity'], f"Not supported output norm `{output_norm}`."
self.hidden_size
self.mode = mode
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.num_heads = num_heads
assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
if feature_map == 'hedgehog':
if tie_feature_map_qk:
self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
else:
self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 't2r':
if tie_feature_map_qk:
self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
else:
self.feature_map_q = T2RFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 'elementwise_product':
if tie_feature_map_qk:
self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
else:
self.feature_map_q = HadamardFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 'dpfp':
self.feature_map_q = DPFPFeatureMap(head_dim=self.head_qk_dim)
self.feature_map_k = DPFPFeatureMap(head_dim=self.head_qk_dim)
elif feature_map == 'elu':
def elu(x):
return F.elu(x) + 1
self.feature_map_q = elu
self.feature_map_k = elu
elif feature_map == 'relu':
self.feature_map_q = nn.ReLU()
self.feature_map_k = nn.ReLU()
elif feature_map == 'identity':
self.feature_map_q = nn.Identity()
self.feature_map_k = nn.Identity()
else:
raise NotImplementedError
self.do_feature_map_norm = do_feature_map_norm
if output_norm == 'rmsnorm':
self.norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
elif output_norm == 'identity':
self.norm = nn.Identity()
else:
raise NotImplementedError
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
self.norm_q = norm_q
self.norm_k = norm_k
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(self, x):
mode = self.mode
q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
q = self.feature_map_q(q)
k = self.feature_map_k(k)
if self.norm_q:
q = q / (q.sum(-1, keepdim=True) + 1e-4)
if self.norm_k:
k = k / (k.sum(-1, keepdim=True) + 1e-4)
if mode == 'chunk':
o = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
elif mode == 'fused_chunk':
o = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
elif mode == 'fused_recurrent':
o = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
else:
raise NotImplementedError
o = self.norm(o)
o = rearrange(o, 'b h n d -> b n (h d)')
o = self.o_proj(o)
return o
if __name__ == '__main__':
import torch
batch = 4
seq_len = 1024
hidden_size = 1024
x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True)
model = LinearAttention(hidden_size, feature_map='dplp').to(torch.bfloat16).cuda()
y = model(x)
print(y.shape)
y.sum().backward()
print(x.grad.shape)

View File

@@ -0,0 +1,271 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange, repeat
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
from fla.modules.rotary import RotaryEmbedding
from fla.ops.retention import (chunk_retention, fused_chunk_retention,
fused_recurrent_retention, parallel_retention)
class MultiScaleRetention(nn.Module):
r"""
The layer implementaion for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf). # noqa
Args:
mode (str, Optional):
Which Retention kernel to use.
Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`.
Default: `fused_chunk`.
hidden_size (int, Optional):
The hidden size of the input. Default: 1024.
expand_k (float, Optional):
The expansion ratio for the key dim. Default: 1.0.
expand_v (float, Optional):
The expansion ratio for the value dim. Default: 2.0.
num_heads (int, Optional):
The number of heads. Default: 8.
num_kv_heads (int, Optional):
The number of key/value heads, used for MQA. Default: None.
feature_map (str, Optional):
Feature map function applied to queries/keys. Default: None.
use_short_conv (bool, Optional):
Whether to use short convolutions. Default: `False`.
conv_size (int, Optional):
The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
conv_bias (bool, Optional):
Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
share_conv_kernel (bool, Optional):
Whether to apply convolutions berfore q/k/v mapping, only taking effects when `use_short_conv`. Default: `True`.
use_output_gate (bool, Optional):
Whether to use output gate. Default: `True`.
gate_fn (str, Optional):
The activation function for the output gate. Default: `swish`.
elementwise_affine (bool, Optional):
If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
norm_eps (float, Optional):
The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
fuse_norm (bool, Optional):
Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
layer_idx (int, Optional):
The index of the layer. Default: None.
"""
def __init__(
self,
mode: str = 'fused_chunk',
hidden_size: int = 1024,
expand_k: float = 1.0,
expand_v: float = 2.0,
num_heads: int = 8,
num_kv_heads: Optional[int] = None,
feature_map: Optional[str] = None,
use_short_conv: bool = False,
conv_size: int = 4,
conv_bias: bool = False,
share_conv_kernel: bool = True,
use_output_gate: bool = True,
gate_fn: str = 'swish',
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
fuse_norm: bool = True,
layer_idx: int = None,
**kwargs
) -> MultiScaleRetention:
super().__init__()
self.mode = mode
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
self.use_short_conv = use_short_conv
self.conv_size = conv_size
self.conv_bias = conv_bias
self.share_conv_kernel = share_conv_kernel
self.use_output_gate = use_output_gate
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.key_dim_per_group = self.key_dim // self.num_kv_groups
self.value_dim_per_group = self.value_dim // self.num_kv_groups
self.layer_idx = layer_idx
assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
if self.use_output_gate:
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
if use_short_conv:
self.conv_size = conv_size
if share_conv_kernel:
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
else:
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
if gate_fn == 'swish' and fuse_norm and use_output_gate:
self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
self.fuse_norm_and_gate = True
else:
self.fuse_norm_and_gate = False
self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
self.gate_fn = ACT2FN[gate_fn]
# TODO: fix this issue
# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180
# Ideally, we would want to support arbitrary d_head_qk
assert self.head_qk_dim <= 256, "head_qk_dim must be less than or equal to 256"
self.rotary = RotaryEmbedding(dim=self.head_qk_dim)
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
# launching the triton kernel for just one token will actually be slower
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
last_state = past_key_values[self.layer_idx] if use_cache else None
if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
else:
conv_state_q = last_state[0] if use_cache else None
conv_state_k = last_state[1] if use_cache else None
conv_state_v = last_state[2] if use_cache else None
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = self.q_conv1d(q, attention_mask, conv_state_q)
k = self.k_conv1d(k, attention_mask, conv_state_k)
v = self.v_conv1d(v, attention_mask, conv_state_v)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# dealing with left-padding
if attention_mask is not None:
v = v.mul_(attention_mask.unsqueeze(-1))
q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
if self.feature_map_fn is not None:
q, k = map(self.feature_map_fn, (q, k))
seqlen_offset, max_seqlen = 0, None
if past_key_values is not None:
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
max_seqlen = q.shape[1] + seqlen_offset
if attention_mask is not None:
# to deliminate the offsets of padding tokens
seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
max_seqlen = q.shape[1] + max(seqlen_offset)
q, k = self.rotary(q, k, seqlen_offset, max_seqlen)
q = q.transpose(1, 2)
if self.num_kv_groups > 1:
k = repeat(k, 'b t h d -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups)
v = repeat(v, 'b t (h d) -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups)
else:
k, v = rearrange(k, 'b t h d -> b h t d'), rearrange(v, 'b t (h d) -> b h t d', h=self.num_kv_heads)
state = last_state[-1] if use_cache else None
if mode == 'chunk':
o, recurrent_state = chunk_retention(q, k, v, initial_state=state, output_final_state=use_cache)
elif mode == 'fused_chunk':
o, recurrent_state = fused_chunk_retention(q, k, v, initial_state=state, output_final_state=use_cache)
elif mode == 'parallel':
o, recurrent_state = parallel_retention(q, k, v, initial_state=state, output_final_state=use_cache)
elif mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_retention(q, k, v, initial_state=state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
if self.use_short_conv:
if self.share_conv_kernel:
last_state = (conv_state, recurrent_state)
else:
last_state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
else:
last_state = (recurrent_state,)
past_key_values.update(last_state, self.layer_idx, q.shape[2])
o = rearrange(o, 'b h l d -> b l h d')
if self.use_output_gate:
g = self.g_proj(hidden_states)
if self.fuse_norm_and_gate:
g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
o = self.g_norm_swish_gate(o, g)
o = rearrange(o, 'b l h d -> b l (h d)')
else:
o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
o = o * self.gate_fn(g)
else:
o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = tuple()
if self.use_short_conv:
if self.share_conv_kernel:
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
else:
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.key_dim, self.conv_size),
param.new_zeros(batch_size, self.value_dim, self.conv_size))
state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
return state
def state_size(self, **kwargs) -> int:
state_size = self.key_dim * self.head_v_dim
for module in self.children():
if isinstance(module, ShortConvolution):
state_size += module.state_size
return state_size

137
finetune/lora/v6/fla/layers/rebased.py vendored Normal file
View File

@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
"""
https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from fla.modules.feature_map import RebasedFeatureMap
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
from fla.ops.rebased import parallel_rebased
class ReBasedLinearAttention(nn.Module):
def __init__(
self,
hidden_size: int,
l_max: int = 2048,
feature_dim: int = 16,
num_key_value_heads: int = 16,
num_heads: int = 16,
use_gamma: Optional[bool] = True,
use_beta: Optional[bool] = True,
normalize: Optional[bool] = True,
causal: bool = True,
eps: float = 1e-5,
mode: str = "parallel",
layer_idx: Optional[int] = None,
**kwargs
) -> ReBasedLinearAttention:
super().__init__()
self.hidden_size = hidden_size
self.l_max = l_max
self.mode = mode
assert self.mode in ["fused_chunk", "parallel", 'chunk']
# linear attention
self.feature_dim = feature_dim
self.num_key_value_heads = num_key_value_heads
self.num_heads = num_heads
self.head_dim = self.hidden_size // self.num_key_value_heads
self.use_gamma = use_gamma
self.use_beta = use_beta
self.normalize = normalize
self.causal = causal
self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.dropout = nn.Identity()
self.eps = eps
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(self, hidden_states: torch.Tensor, **kwargs):
mode = self.mode
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v])
q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
if mode == "fused_chunk":
o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)
elif mode == 'chunk':
o = chunk_linear_attn(q, k, v, normalize=True, scale=1)
elif mode == 'parallel':
assert q.shape[-1] <= 128
o = parallel_rebased(q, k, v, self.eps, True, True)
o = rearrange(o, "b h l d -> b l (h d)")
o = self.o_proj(o)
o = self.dropout(o)
return o
# https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
"""
x (torch.Tensor): tensor of shape (b, d, l)
y (torch.Tensor): tensor of shape (b, d, l)
"""
# hidden_states = hidden_states.transpose(1, 2)
b, l, _ = hidden_states.size()
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2)
k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# Linear attention
q, k = self.feature_map(q), self.feature_map(k)
q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
# Compute attention
if self.causal:
y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
else:
y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
y = rearrange(y, 'b h l d -> b l (h d)')
y = self.o_proj(y.to(hidden_states.dtype))
y = self.dropout(y)
return y.to(hidden_states.dtype)
if __name__ == '__main__':
batch = 4
seq_len = 1024
hidden_size = 1024
dtype = torch.float32
x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True)
dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda()
model = ReBasedLinearAttention(hidden_size=hidden_size, mode='parallel').to(dtype).cuda()
y = model(x)
y.backward(dy, retain_graph=True)
x_grad, x.grad = x.grad, None
print(model.mode)
model.mode = 'fused_chunk'
y2 = model(x)
print(model.mode)
y2.backward(dy)
# assert y.allclose(y2, 0, 1e-4), breakpoint()
# assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint()
print("Pass")

264
finetune/lora/v6/fla/layers/rwkv6.py vendored Normal file
View File

@@ -0,0 +1,264 @@
# -*- coding: utf-8 -*-
# "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence"[https://arxiv.org/abs/2404.05892]
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from fla.modules import FusedLayerNormSwishGate, LayerNorm
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
class RWKV6Attention(nn.Module):
def __init__(
self,
mode: str = 'chunk',
hidden_size: int = 1024,
expand_k: float = 0.5,
expand_v: float = 1.0,
num_heads: int = 4,
gate_fn: str = 'swish',
proj_low_rank_dim: int = 32,
gate_low_rank_dim: int = 64,
fuse_norm: bool = True,
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
layer_idx: int = None,
**kwargs
) -> RWKV6Attention:
super().__init__()
self.mode = mode
self.hidden_size = hidden_size
self.expand_k = expand_k
self.expand_v = expand_v
self.num_heads = num_heads
self.proj_low_rank_dim = proj_low_rank_dim
self.gate_low_rank_dim = gate_low_rank_dim
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.layer_idx = layer_idx
assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.x_proj = nn.Sequential(
LerpLinear(hidden_size, proj_low_rank_dim * 5),
nn.Tanh(),
nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=True)
)
self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_qk_dim))
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
if gate_fn == 'swish' and fuse_norm:
self.g_norm_swish_gate = FusedLayerNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
self.fuse_norm_and_gate = True
else:
self.fuse_norm_and_gate = False
self.g_norm = LayerNorm(self.head_v_dim, elementwise_affine, norm_eps)
self.gate_fn = ACT2FN[gate_fn]
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
if isinstance(module, nn.Parameter):
nn.init.xavier_uniform_(module, gain=2 ** -2.5)
module._is_hf_initialized = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
batch_size, seq_len, hidden_size = hidden_states.size()
# launching the triton kernel for just one token will actually be slower
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
delta = self.time_shift(hidden_states) - hidden_states
x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
r, w, k, v, g = torch.einsum('b l n r, n r d-> b l n d',
self.x_proj[1](x),
self.x_proj[2].weight.view(5, -1, hidden_size)).unbind(-2)
r = self.r_proj(hidden_states, r, delta)
w = self.w_proj(hidden_states, w, delta)
k = self.k_proj(hidden_states, k, delta)
v = self.v_proj(hidden_states, v, delta)
g = self.g_proj(hidden_states, g, delta)
# dealing with left-padding
if attention_mask is not None:
v = v.mul_(attention_mask.unsqueeze(-1))
r, w, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (r, w, k, v))
w = -torch.exp(w)
u = self.bonus
last_state = past_key_values[self.layer_idx] if use_cache else None
state = last_state[-1] if use_cache else None
if mode == 'fused_recurrent':
o, recurrent_state = fused_recurrent_rwkv6(r, k, v, w, u, initial_state=state, output_final_state=use_cache)
elif mode == 'chunk':
o, recurrent_state = chunk_rwkv6(r, k, v, w, u, initial_state=state, output_final_state=use_cache)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
if past_key_values is not None:
past_key_values.update((recurrent_state,), self.layer_idx, r.shape[2])
o = rearrange(o, 'b h l d -> b l h d')
if self.fuse_norm_and_gate:
g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
o = self.g_norm_swish_gate(o, g)
o = rearrange(o, 'b l h d -> b l (h d)')
else:
o = self.g_norm(o)
o = rearrange(o, 'b l h d -> b l (h d)')
o = o * self.gate_fn(g)
o = self.o_proj(o)
return o, None, past_key_values
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
param = next(self.parameters())
state = (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
return state
def state_size(self, **kwargs) -> int:
state_size = self.key_dim * self.head_v_dim
return state_size
class LoRA(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
low_rank_dim: int,
bias: Optional[bool] = True
):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.low_rank_dim = low_rank_dim
self.bias = bias
self.lora = nn.Sequential(
nn.Linear(input_dim, low_rank_dim, bias=False),
nn.Tanh(),
nn.Linear(low_rank_dim, output_dim, bias=bias)
)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}("
s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
if not self.bias:
s += f", bias={self.bias}"
s += ")"
return s
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lora(x)
class LerpLinear(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
low_rank_dim: Optional[int] = None
):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.low_rank_dim = low_rank_dim
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
if low_rank_dim is None:
self.linear = nn.Linear(input_dim, output_dim, bias=False)
else:
self.linear = LoRA(input_dim, output_dim, low_rank_dim)
self.mu = nn.Parameter(torch.zeros(input_dim))
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
if self.low_rank_dim is not None:
s += f", low_rank_dim={self.low_rank_dim}"
s += ")"
return s
def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
if delta is None:
shifted = self.time_shift(x)
if len(shifted.shape) == 2:
shifted = shifted.unsqueeze(1)
delta = shifted - x
return self.linear(x + delta * self.mu)
class DDLerpLinear(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
low_rank_dim: Optional[int] = None
):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.low_rank_dim = low_rank_dim
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
if low_rank_dim is None:
self.linear = nn.Linear(input_dim, output_dim, bias=False)
else:
self.linear = LoRA(input_dim, output_dim, low_rank_dim)
def __repr__(self) -> str:
s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
if self.low_rank_dim is not None:
s += f", low_rank_dim={self.low_rank_dim}"
s += ")"
return s
def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
if delta is None:
shifted = self.time_shift(x)
if len(shifted.shape) == 2:
shifted = shifted.unsqueeze(1)
delta = shifted - x
return self.linear(x + delta * mu)

View File

@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN
from fla.modules import FusedRMSNormSwishGate, RMSNorm
from fla.ops.simple_gla import chunk_simple_gla
class SimpleGatedLinearAttention(nn.Module):
r"""
The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.
Args:
mode (str, Optional):
Which GLA kernel to use.
Currently available: `chunk`.
Default: `chunk`.
hidden_size (int, Optional):
The hidden size of the input. Default: 1024.
expand_k (float, Optional):
The expansion ratio for the key dim. Default: 0.5.
expand_v (float, Optional):
The expansion ratio for the value dim. Default: 1.0.
num_heads (int, Optional):
The number of heads. Default: 4.
gate_fn (str, Optional):
The activation function for the output gate. Default: `swish`.
elementwise_affine (bool, Optional):
If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
norm_eps (float, Optional):
The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
gate_logit_normalizer (int, Optional):
The normalizer for the gate logits, appied after `logsigmoid`. Default: 16.
fuse_norm (bool, Optional):
Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
layer_idx (int, Optional):
The index of the layer. Default: None.
"""
def __init__(
self,
mode: str = 'chunk',
hidden_size: int = 1024,
expand_k: float = 1.0,
expand_v: float = 2.0,
num_heads: int = 4,
gate_fn: str = 'swish',
elementwise_affine: Optional[bool] = True,
norm_eps: float = 1e-5,
gate_logit_normalizer: int = 16,
fuse_norm: bool = True,
**kwargs
) -> SimpleGatedLinearAttention:
super().__init__()
self.hidden_size = hidden_size
self.mode = mode
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
assert mode in ['chunk'], f"Not suppoerted mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.num_heads = num_heads
self.head_qk_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.gate_fn = ACT2FN[gate_fn]
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
self.gk_proj = nn.Linear(hidden_size, self.num_heads)
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
if gate_fn == 'swish' and fuse_norm:
self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
self.fuse_norm_and_gate = True
else:
self.fuse_norm_and_gate = False
self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
self.gate_logit_normalizer = gate_logit_normalizer
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
if getattr(module, "_is_hf_initialized", False):
return
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
if module.bias is not None:
nn.init.zeros_(module.bias)
module._is_hf_initialized = True
def forward(self, x):
mode = self.mode
q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
gk = rearrange(self.gk_proj(x), 'b n h -> b h n')
gk = (F.logsigmoid(gk) / self.gate_logit_normalizer)
if mode == 'chunk':
o = chunk_simple_gla(q, k, v, gk)
else:
raise NotImplementedError(f"Not supported mode `{mode}`.")
o = rearrange(o, 'b h l d -> b l h d')
g = self.g_proj(x)
if self.fuse_norm_and_gate:
g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
o = self.g_norm_swish_gate(o, g)
o = rearrange(o, 'b l h d -> b l (h d)')
else:
o = self.g_norm(o)
o = rearrange(o, 'b l h d -> b l (h d)')
o = o * self.gate_fn(g)
o = self.o_proj(o)
return o
if __name__ == '__main__':
batch = 4
seq_len = 1024
hidden_size = 2048
x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True)
model = SimpleGatedLinearAttention(hidden_size=hidden_size, mode='chunk').to(torch.bfloat16).cuda()
y = model(x)
print(y.shape)
y.sum().backward()
print(x.grad.shape)