finetune/lora/v6/fla/layers/__init__.py (vendored, new file, 25 lines)
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-

from .abc import ABCAttention
from .based import BasedLinearAttention
from .delta_net import DeltaNet
from .gla import GatedLinearAttention
from .hgrn import HGRNAttention
from .hgrn2 import HGRN2Attention
from .linear_attn import LinearAttention
from .multiscale_retention import MultiScaleRetention
from .rebased import ReBasedLinearAttention
from .rwkv6 import RWKV6Attention

__all__ = [
    'ABCAttention',
    'BasedLinearAttention',
    'DeltaNet',
    'GatedLinearAttention',
    'HGRNAttention',
    'HGRN2Attention',
    'LinearAttention',
    'MultiScaleRetention',
    'ReBasedLinearAttention',
    'RWKV6Attention'
]
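The package exports a uniform set of token-mixing layers; the classes listed in `__all__` broadly share the same constructor pattern (`hidden_size`, `num_heads`, `layer_idx`, ...) and the same `(output, attentions, past_key_values)` forward signature. A quick, illustrative sanity check of the export surface (not part of the vendored file; it assumes `finetune/lora/v6` is on `PYTHONPATH` so the tree is importable as the `fla` package):

import torch.nn as nn

import fla.layers as layers  # assumes finetune/lora/v6 is on PYTHONPATH

# every name in __all__ should resolve to an nn.Module subclass
for name in layers.__all__:
    cls = getattr(layers, name)
    assert issubclass(cls, nn.Module), name
    print(f"{name} -> {cls.__module__}.{cls.__name__}")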
finetune/lora/v6/fla/layers/abc.py (vendored, new file, 195 lines)
@@ -0,0 +1,195 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from fla.modules import (FusedRMSNormSwishGate, RMSNorm, RotaryEmbedding,
|
||||
ShortConvolution)
|
||||
from fla.modules.activations import swiglu, swish
|
||||
from fla.modules.convolution import proj_then_conv1d
|
||||
from fla.ops.abc.chunk import chunk_abc
|
||||
|
||||
|
||||
class ABCAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1024,
|
||||
expand_k: float = 0.5,
|
||||
expand_v: float = 1.0,
|
||||
num_heads: int = 4,
|
||||
use_short_conv: bool = False,
|
||||
conv_size: int = 4,
|
||||
conv_bias: bool = False,
|
||||
share_conv_kernel: bool = True,
|
||||
num_slots: Optional[int] = None,
|
||||
elementwise_affine: Optional[bool] = True,
|
||||
norm_eps: float = 1e-5,
|
||||
gate_low_rank_dim: int = 16,
|
||||
gate_logit_normalizer: int = 16,
|
||||
        use_input_gate: bool = False,
        use_output_gate: bool = True,
        use_rope: bool = False,  # assumed default: `self.use_rope` is read below but no argument set it
        use_norm: bool = True,
|
||||
clamp_min: Optional[float] = -32,
|
||||
clamp_max: Optional[float] = 32,
|
||||
layer_idx: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> ABCAttention:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.expand_k = expand_k
|
||||
self.expand_v = expand_v
|
||||
self.num_heads = num_heads
|
||||
self.key_dim = int(self.hidden_size * self.expand_k)
|
||||
self.value_dim = int(self.hidden_size * self.expand_v)
|
||||
self.head_k_dim = self.key_dim // self.num_heads
|
||||
self.head_v_dim = self.value_dim // self.num_heads
|
||||
|
||||
self.use_short_conv = use_short_conv
|
||||
self.conv_size = conv_size
|
||||
self.conv_bias = conv_bias
|
||||
self.share_conv_kernel = share_conv_kernel
|
||||
|
||||
self.gate_low_rank_dim = gate_low_rank_dim
|
||||
self.gate_logit_normalizer = gate_logit_normalizer
|
||||
|
||||
self.use_input_gate = use_input_gate
|
||||
        self.use_output_gate = use_output_gate
        self.use_rope = use_rope
        self.use_norm = use_norm
|
||||
|
||||
if num_slots is None:
|
||||
num_slots = self.head_k_dim
|
||||
self.num_slots = num_slots
|
||||
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
self.clamp_min = clamp_min
|
||||
self.clamp_max = clamp_max
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
if layer_idx is None:
|
||||
            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
|
||||
|
||||
if use_output_gate:
|
||||
self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
|
||||
self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
|
||||
self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
|
||||
|
||||
if use_short_conv:
|
||||
self.conv_size = conv_size
|
||||
if share_conv_kernel:
|
||||
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
|
||||
else:
|
||||
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
|
||||
self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
|
||||
self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
|
||||
|
||||
if self.use_norm:
|
||||
if self.use_output_gate:
|
||||
self.g_norm = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
|
||||
else:
|
||||
self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
|
||||
|
||||
if self.use_rope:
|
||||
self.rotary = RotaryEmbedding(self.head_k_dim)
|
||||
|
||||
self.apply(self._initialize_weights)
|
||||
|
||||
def _initialize_weights(self, module: nn.Module):
|
||||
if getattr(module, "_is_hf_initialized", False):
|
||||
return
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
module._is_hf_initialized = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
|
||||
if self.use_short_conv:
|
||||
if self.share_conv_kernel:
|
||||
hidden_states = self.h_conv1d(hidden_states)
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
else:
|
||||
q = proj_then_conv1d(hidden_states, self.q_proj.weight, self.q_conv1d.weight, self.q_conv1d.bias)
|
||||
k = proj_then_conv1d(hidden_states, self.k_proj.weight, self.k_conv1d.weight, self.k_conv1d.bias)
|
||||
v = proj_then_conv1d(hidden_states, self.v_proj.weight, self.v_conv1d.weight, self.v_conv1d.bias)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
|
||||
if self.use_input_gate:
|
||||
q, k, v = map(lambda x: swish(x), (q, k, v))
|
||||
|
||||
if self.use_rope:
|
||||
q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
|
||||
k = rearrange(k, '... (h d) -> ... h d', h=self.num_heads)
|
||||
seqlen_offset = 0
|
||||
if past_key_values is not None:
|
||||
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
|
||||
q, k = self.rotary(q, k, seqlen_offset)
|
||||
q = rearrange(q, 'b n h d -> b h n d', h=self.num_heads)
|
||||
k = rearrange(k, 'b n h d -> b h n d', h=self.num_heads)
|
||||
else:
|
||||
q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
|
||||
k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
|
||||
v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)
|
||||
|
||||
# [batch_size, n_heads, seq_len, num_slots]
|
||||
s = rearrange(self.s_proj(hidden_states), 'b t (h m) -> b h t m', h=self.num_heads)
|
||||
s = s.clamp_(self.clamp_min, self.clamp_max)
|
||||
|
||||
last_state = past_key_values[self.layer_idx] if use_cache else None
|
||||
o, last_state = chunk_abc(q, k, v, s, initial_state=last_state, output_final_state=use_cache)
|
||||
if past_key_values is not None and last_state is not None:
|
||||
past_key_values.update(last_state, self.layer_idx, q.shape[2])
|
||||
|
||||
o = rearrange(o, 'b h t d -> b t h d')
|
||||
if self.use_norm and not self.use_output_gate:
|
||||
o = self.g_norm(o)
|
||||
elif self.use_output_gate:
|
||||
g = rearrange(self.g_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_heads)
|
||||
o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
|
||||
o = rearrange(o, 'b t h d -> b t (h d)')
|
||||
o = self.o_proj(o)
|
||||
|
||||
return o, None, past_key_values
|
||||
|
||||
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
|
||||
param = next(self.parameters())
|
||||
state = tuple()
|
||||
if self.use_short_conv:
|
||||
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
|
||||
state += (param.new_zeros(batch_size, self.num_heads, self.head_k_dim, self.num_slots),
|
||||
param.new_zeros(batch_size, self.num_heads, self.num_slots, self.head_v_dim))
|
||||
return state
|
||||
|
||||
def state_size(self, sequence_length: int = 2048):
|
||||
return self.num_heads * self.key_dim * self.head_v_dim
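A minimal usage sketch for `ABCAttention` as defined above (not part of the vendored file; the `fla` import path, the `use_rope=False` default, and a CUDA device for the Triton `chunk_abc` kernel are assumptions):

import torch
from fla.layers.abc import ABCAttention

layer = ABCAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda()
x = torch.randn(2, 128, 1024, device='cuda')  # (batch, seq_len, hidden_size)
o, _, past_kv = layer(x)  # forward returns (output, attentions=None, past_key_values)
assert o.shape == x.shape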
finetune/lora/v6/fla/layers/based.py (vendored, new file, 126 lines)
@@ -0,0 +1,126 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Linear attention in Based.
|
||||
https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
from fla.modules.feature_map import TaylorFeatureMap
|
||||
from fla.ops.based import parallel_based
|
||||
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
|
||||
|
||||
|
||||
class BasedLinearAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
l_max: int = 2048,
|
||||
feature_dim: int = 16,
|
||||
num_key_value_heads: int = 12,
|
||||
num_heads: int = 12,
|
||||
feature_name: str = "taylor_exp",
|
||||
eps: float = 1e-12,
|
||||
causal: bool = True,
|
||||
mode: str = "parallel",
|
||||
):
|
||||
super().__init__()
|
||||
        self.hidden_size = hidden_size
|
||||
self.l_max = l_max
|
||||
self.mode = mode
|
||||
assert self.mode in ["fused_chunk", "parallel", 'chunk']
|
||||
|
||||
# linear attention
|
||||
self.feature_name = feature_name
|
||||
self.feature_dim = feature_dim
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = self.hidden_size // self.num_key_value_heads
|
||||
self.causal = causal
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
self.dropout = nn.Identity()
|
||||
self.feature_map = TaylorFeatureMap(feature_dim)
|
||||
self.eps = eps
|
||||
|
||||
self.apply(self._initialize_weights)
|
||||
|
||||
def _initialize_weights(self, module: nn.Module):
|
||||
if getattr(module, "_is_hf_initialized", False):
|
||||
return
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
module._is_hf_initialized = True
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, **kwargs):
|
||||
mode = self.mode
|
||||
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
|
||||
q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v])
|
||||
if mode == "fused_chunk":
|
||||
q, k = self.feature_map(q), self.feature_map(k)
|
||||
o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)
|
||||
elif mode == 'chunk':
|
||||
q, k = self.feature_map(q), self.feature_map(k)
|
||||
o = chunk_linear_attn(q, k, v, normalize=True, scale=1)
|
||||
elif mode == 'parallel':
|
||||
assert q.shape[-1] <= 128
|
||||
o = parallel_based(q, k, v, True, True)
|
||||
o = rearrange(o, "b h l d -> b l (h d)")
|
||||
o = self.o_proj(o)
|
||||
o = self.dropout(o)
|
||||
return o
|
||||
|
||||
# https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
|
||||
|
||||
def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
|
||||
"""
|
||||
x (torch.Tensor): tensor of shape (b, d, l)
|
||||
y (torch.Tensor): tensor of shape (b, d, l)
|
||||
"""
|
||||
# hidden_states = hidden_states.transpose(1, 2)
|
||||
b, l, _ = hidden_states.size()
|
||||
q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
|
||||
|
||||
q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2)
|
||||
k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
|
||||
v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Linear attention
|
||||
q, k = self.feature_map(q), self.feature_map(k)
|
||||
q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
|
||||
|
||||
# Compute attention
|
||||
if self.causal:
|
||||
y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
|
||||
else:
|
||||
y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
|
||||
y = rearrange(y, 'b h l d -> b l (h d)')
|
||||
y = self.o_proj(y.to(hidden_states.dtype))
|
||||
y = self.dropout(y)
|
||||
return y.to(hidden_states.dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
batch = 4
|
||||
seq_len = 1024
|
||||
hidden_size = 1024
|
||||
dtype = torch.float32
|
||||
x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True)
|
||||
dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda()
|
||||
model = BasedLinearAttention(hidden_size, mode='chunk').to(dtype).cuda()
|
||||
y = model(x)
|
||||
y.backward(dy, retain_graph=True)
|
||||
x_grad, x.grad = x.grad, None
|
||||
y2 = model.forward_reference(x)
|
||||
y2.backward(dy)
|
||||
assert y.allclose(y2, 0, 1e-4), breakpoint()
|
||||
assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint()
|
||||
print("Pass")
finetune/lora/v6/fla/layers/delta_net.py (vendored, new file, 254 lines)
@@ -0,0 +1,254 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Sec. 4.2 of "Linear Transformers Are Secretly Fast Weight Programmers" (https://arxiv.org/abs/2102.11174)
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
|
||||
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution, LayerNorm
|
||||
from fla.modules.rotary import RotaryEmbedding
|
||||
from fla.ops.delta_rule import (fused_chunk_delta_rule,
|
||||
fused_recurrent_linear_attn_delta_rule,
|
||||
chunk_delta_rule)
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def simple_norm(x):
|
||||
return (F.normalize(x, dim=-1) * x.shape[-1] ** 0.5).to(x)
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
def elu_p1(x):
|
||||
return (F.elu(x, 1., False) + 1.).to(x)
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
def sum_norm(x):
|
||||
return (x / x.sum(-1, keepdim=True)).to(x)
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
def elu_norm(x):
|
||||
dtype = x.dtype
|
||||
x = F.elu(x, 1., False) + 1.
|
||||
return (x / x.sum(-1, keepdim=True)).to(dtype)
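These helpers build the positive, sum-normalized feature maps applied to q/k when `qk_activation='elu'` or `qk_norm='sum'`; a quick illustrative sanity check (not part of the vendored file):

import torch

x = torch.randn(2, 4, 64, 16)
y = elu_p1(x)       # ELU(x) + 1 is strictly positive
z = sum_norm(y)     # rows then sum to 1 along the last dimension
assert (y > 0).all()
assert torch.allclose(z.sum(-1), torch.ones(2, 4, 64))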
# https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1
|
||||
class DeltaNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = None,
|
||||
hidden_size: int = 1024,
|
||||
expand_k: float = 1.0,
|
||||
expand_v: float = 1.0,
|
||||
num_heads: int = 4,
|
||||
mode: str = 'fused_chunk',
|
||||
chunk_size: int = 16,
|
||||
use_beta: bool = True,
|
||||
use_gate: bool = True,
|
||||
use_rope: bool = False,
|
||||
use_output_norm: bool = True,
|
||||
use_elu: bool = False,
|
||||
use_short_conv: bool = True,
|
||||
conv_size: int = 4,
|
||||
conv_bias: bool = False,
|
||||
share_conv_kernel: bool = False,
|
||||
layer_idx: int = None,
|
||||
qk_activation: str = 'silu',
|
||||
qk_norm: str = None,
|
||||
        save_memory: bool = False,
|
||||
**kwargs
|
||||
) -> DeltaNet:
|
||||
super().__init__()
|
||||
self.mode = mode
|
||||
self.qk_activation = qk_activation
|
||||
self.qk_norm = qk_norm
|
||||
assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
|
||||
        assert self.qk_norm in ['l2', 'sum', None]  # None disables the extra q/k normalization (see forward)
|
||||
if d_model is not None:
|
||||
hidden_size = d_model
|
||||
self.hidden_size = hidden_size
|
||||
self.expand_k = expand_k
|
||||
self.expand_v = expand_v
|
||||
self.num_heads = num_heads
|
||||
self.chunk_size = chunk_size
|
||||
self.use_gate = use_gate
|
||||
self.use_output_norm = use_output_norm
|
||||
self.use_short_conv = use_short_conv
|
||||
self.conv_size = conv_size
|
||||
self.conv_bias = conv_bias
|
||||
self.share_conv_kernel = share_conv_kernel
|
||||
|
||||
self.key_dim = int(hidden_size * expand_k)
|
||||
self.value_dim = int(hidden_size * expand_v)
|
||||
self.head_qk_dim = self.key_dim // num_heads
|
||||
self.head_v_dim = self.value_dim // num_heads
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.silu = torch.nn.SiLU()
|
||||
|
||||
        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
|
||||
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
|
||||
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
|
||||
|
||||
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
|
||||
self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
|
||||
self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
|
||||
|
||||
self.use_beta = use_beta
|
||||
self.use_elu = use_elu
|
||||
if self.use_beta:
|
||||
self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
|
||||
if use_short_conv:
|
||||
self.conv_size = conv_size
|
||||
if share_conv_kernel:
|
||||
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation=None)
|
||||
else:
|
||||
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
|
||||
self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
|
||||
self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
|
||||
if use_gate:
|
||||
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
|
||||
if self.use_gate:
|
||||
self.norm = FusedRMSNormSwishGate(self.head_v_dim)
|
||||
else:
|
||||
self.norm = RMSNorm(self.head_v_dim)
|
||||
self.apply(self._initialize_weights)
|
||||
|
||||
def _initialize_weights(self, module: nn.Module):
|
||||
if getattr(module, "_is_hf_initialized", False):
|
||||
return
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
module._is_hf_initialized = True
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
|
||||
        # switch to the fused recurrent kernel for short sequences (e.g. single-token decoding)
|
||||
mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode
|
||||
last_state = past_key_values[self.layer_idx] if use_cache else None
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.shape[-1] != hidden_states.shape[-2]:
|
||||
attention_mask = attention_mask[:, -1:]
|
||||
|
||||
if self.use_short_conv:
|
||||
conv_state = last_state[0] if use_cache else None
|
||||
if self.share_conv_kernel:
|
||||
# conv state is updated inplace
|
||||
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
else:
|
||||
conv_state_q = last_state[0] if use_cache else None
|
||||
conv_state_k = last_state[1] if use_cache else None
|
||||
conv_state_v = last_state[2] if use_cache else None
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
q = self.q_proj(hidden_states)
|
||||
q = self.q_conv1d(q, attention_mask, conv_state_q)
|
||||
k = self.k_conv1d(k, attention_mask, conv_state_k)
|
||||
v = self.v_conv1d(v, attention_mask, conv_state_v)
|
||||
else:
|
||||
q = (self.q_proj(hidden_states))
|
||||
k = (self.k_proj(hidden_states))
|
||||
v = self.silu(self.v_proj(hidden_states))
|
||||
|
||||
# dealing with left-padding
|
||||
if attention_mask is not None:
|
||||
v = v.mul_(attention_mask.unsqueeze(-1))
|
||||
|
||||
q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v))
|
||||
|
||||
if self.qk_activation != 'silu':
|
||||
if self.qk_activation == 'relu':
|
||||
q, k = q.relu(), k.relu()
|
||||
elif self.qk_activation == 'elu':
|
||||
q, k = elu_p1(q), elu_p1(k)
|
||||
elif self.qk_activation == 'identity':
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.qk_norm is not None:
|
||||
if self.qk_norm == 'l2':
|
||||
k = torch.nn.functional.normalize(k, dim=-1, p=2).to(v) #auto mixed precision type transfer is annoying.
|
||||
q = torch.nn.functional.normalize(q, dim=-1, p=2).to(v)
|
||||
elif self.qk_norm == 'sum':
|
||||
q = sum_norm(q).to(v)
|
||||
k = sum_norm(k).to(v)
|
||||
|
||||
if self.use_beta:
|
||||
beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid()
|
||||
else:
|
||||
beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
|
||||
state = past_key_values[self.layer_idx][-1] if use_cache else None
|
||||
if mode == 'fused_recurrent':
|
||||
o, recurrent_state = fused_recurrent_linear_attn_delta_rule(q, k, v, beta, state, output_final_state=use_cache)
|
||||
elif mode == 'fused_chunk':
|
||||
assert self.chunk_size in [16, 32, 64]
|
||||
o, recurrent_state = fused_chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
|
||||
elif mode == 'chunk':
|
||||
assert self.chunk_size in [16, 32, 64]
|
||||
o, recurrent_state = chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache)
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported mode `{mode}`.")
|
||||
|
||||
if past_key_values is not None:
|
||||
if self.use_short_conv:
|
||||
if self.share_conv_kernel:
|
||||
state = (conv_state, recurrent_state)
|
||||
else:
|
||||
state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
|
||||
else:
|
||||
state = (recurrent_state,)
|
||||
past_key_values.update(state, self.layer_idx)
|
||||
|
||||
o = rearrange(o, 'b h l d -> b l h d')
|
||||
if self.use_gate:
|
||||
g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads)
|
||||
o = self.norm(o, g)
|
||||
else:
|
||||
o = self.norm(o)
|
||||
o = rearrange(o, 'b l h d -> b l (h d)')
|
||||
o = self.o_proj(o)
|
||||
|
||||
return o, None, past_key_values
|
||||
|
||||
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
|
||||
param = next(self.parameters())
|
||||
state = tuple()
|
||||
if self.use_short_conv:
|
||||
if self.share_conv_kernel:
|
||||
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
|
||||
else:
|
||||
# for q/k/v each
|
||||
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
|
||||
param.new_zeros(batch_size, self.key_dim, self.conv_size),
|
||||
param.new_zeros(batch_size, self.value_dim, self.conv_size))
|
||||
state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
|
||||
return state
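A minimal sketch of instantiating and calling `DeltaNet` (illustrative; the `fla` import path and a CUDA device are assumptions, and `qk_norm='l2'` is passed explicitly to satisfy the constructor's assertion):

import torch
from fla.layers.delta_net import DeltaNet

# defaults keep use_short_conv=True with per-projection conv kernels and use_beta/use_gate enabled
layer = DeltaNet(hidden_size=1024, num_heads=4, qk_norm='l2', layer_idx=0).cuda()
x = torch.randn(2, 128, 1024, device='cuda')
o, _, _ = layer(x)  # seq_len >= 64, so the default 'fused_chunk' kernel is used
assert o.shape == x.shape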
finetune/lora/v6/fla/layers/gated_abc.py (vendored, new file, 234 lines)
@@ -0,0 +1,234 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from fla.modules import (FusedRMSNormSwishGateLinear, RMSNormLinear,
|
||||
RotaryEmbedding, ShortConvolution)
|
||||
from fla.modules.activations import ACT2FN, swiglu_linear, swish
|
||||
from fla.ops.abc.chunk_gate import chunk_gated_abc
|
||||
|
||||
|
||||
class GatedABCAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1024,
|
||||
expand_k: float = 1.,
|
||||
expand_v: float = 1.,
|
||||
num_heads: int = 4,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
use_short_conv: bool = False,
|
||||
conv_size: int = 4,
|
||||
conv_bias: bool = False,
|
||||
share_conv_kernel: bool = True,
|
||||
num_slots: Optional[int] = None,
|
||||
elementwise_affine: Optional[bool] = True,
|
||||
norm_eps: float = 1e-5,
|
||||
gate_low_rank_dim: Optional[int] = None,
|
||||
gate_logit_normalizer: int = 16,
|
||||
feature_map: str = 'swish',
|
||||
use_rope: bool = False,
|
||||
use_output_gate: bool = False,
|
||||
use_norm: bool = True,
|
||||
layer_idx: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> GatedABCAttention:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.expand_k = expand_k
|
||||
self.expand_v = expand_v
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.num_kv_groups = self.num_heads // self.num_kv_heads
|
||||
self.key_dim = int(hidden_size * expand_k)
|
||||
self.value_dim = int(hidden_size * expand_v)
|
||||
self.key_dim_per_group = self.key_dim // self.num_kv_groups
|
||||
self.value_dim_per_group = self.value_dim // self.num_kv_groups
|
||||
self.head_k_dim = self.key_dim // self.num_heads
|
||||
self.head_v_dim = self.value_dim // self.num_heads
|
||||
|
||||
self.use_short_conv = use_short_conv
|
||||
self.conv_size = conv_size
|
||||
self.conv_bias = conv_bias
|
||||
self.share_conv_kernel = share_conv_kernel
|
||||
|
||||
if gate_low_rank_dim is None:
|
||||
gate_low_rank_dim = self.hidden_size // 16
|
||||
self.gate_low_rank_dim = gate_low_rank_dim
|
||||
self.gate_logit_normalizer = gate_logit_normalizer
|
||||
|
||||
self.feature_map = feature_map
|
||||
self.use_rope = use_rope
|
||||
self.use_output_gate = use_output_gate
|
||||
self.use_norm = use_norm
|
||||
|
||||
if num_slots is None:
|
||||
num_slots = self.head_k_dim
|
||||
self.num_slots = num_slots
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
if layer_idx is None:
|
||||
            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
|
||||
self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)
|
||||
|
||||
if use_output_gate:
|
||||
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
|
||||
|
||||
if use_short_conv:
|
||||
self.conv_size = conv_size
|
||||
if share_conv_kernel:
|
||||
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
|
||||
else:
|
||||
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
|
||||
self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
|
||||
self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
|
||||
|
||||
if self.use_norm:
|
||||
if self.use_output_gate:
|
||||
self.g_norm = FusedRMSNormSwishGateLinear(self.hidden_size, elementwise_affine, norm_eps)
|
||||
else:
|
||||
self.g_norm = RMSNormLinear(self.hidden_size, elementwise_affine, norm_eps)
|
||||
self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
|
||||
|
||||
if self.use_rope:
|
||||
self.rotary = RotaryEmbedding(self.head_k_dim)
|
||||
|
||||
self.apply(self._initialize_weights)
|
||||
|
||||
def _initialize_weights(self, module: nn.Module):
|
||||
if getattr(module, "_is_hf_initialized", False):
|
||||
return
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
module._is_hf_initialized = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
lower_bound: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
|
||||
last_state = past_key_values[self.layer_idx] if use_cache else None
|
||||
if self.use_short_conv:
|
||||
conv_state = last_state[0] if use_cache else None
|
||||
if self.share_conv_kernel:
|
||||
# conv state is updated inplace
|
||||
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
else:
|
||||
conv_state_q = last_state[0] if use_cache else None
|
||||
conv_state_k = last_state[1] if use_cache else None
|
||||
conv_state_v = last_state[2] if use_cache else None
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
q = self.q_conv1d(q, attention_mask, conv_state_q)
|
||||
k = self.k_conv1d(k, attention_mask, conv_state_k)
|
||||
v = self.v_conv1d(v, attention_mask, conv_state_v)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
f = self.f_proj(hidden_states)
|
||||
|
||||
if self.use_rope:
|
||||
q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
|
||||
k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
|
||||
seqlen_offset = 0
|
||||
if past_key_values is not None:
|
||||
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
|
||||
q, k = self.rotary(q, k, seqlen_offset)
|
||||
q = rearrange(q, 'b n h d -> b h n d', h=self.num_heads)
|
||||
k = rearrange(k, 'b n h d -> b h n d', h=self.num_kv_heads)
|
||||
else:
|
||||
q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
|
||||
if self.num_kv_groups > 1:
|
||||
k = repeat(k, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups)
|
||||
else:
|
||||
k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_kv_heads)
|
||||
if self.num_kv_groups > 1:
|
||||
v = repeat(v, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups)
|
||||
f = repeat(f, 'b n (h m) -> b (h g) n m', h=self.num_kv_heads, g=self.num_kv_groups)
|
||||
else:
|
||||
v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_kv_heads)
|
||||
f = rearrange(f, 'b n (h m) -> b h n m', h=self.num_kv_heads)
|
||||
|
||||
if self.feature_map is not None:
|
||||
q, k, v = map(lambda x: ACT2FN[self.feature_map](x), (q, k, v))
|
||||
f = F.logsigmoid(f) / self.gate_logit_normalizer
|
||||
s = (1 - f.exp()).to(f.dtype)
|
||||
# dealing with left-padding
|
||||
if attention_mask is not None:
|
||||
s = s.mul_(attention_mask.view(attention_mask.shape[0], 1, -1, 1))
|
||||
v = v.mul_(attention_mask.view(attention_mask.shape[0], 1, -1, 1))
|
||||
|
||||
recurrent_state = last_state[-2:] if use_cache else None
|
||||
o, recurrent_state = chunk_gated_abc(q, k, v, s, f,
|
||||
initial_state=recurrent_state,
|
||||
output_final_state=use_cache)
|
||||
if past_key_values is not None:
|
||||
if self.use_short_conv:
|
||||
if self.share_conv_kernel:
|
||||
last_state = (conv_state,) + recurrent_state
|
||||
else:
|
||||
last_state = (conv_state_q, conv_state_k, conv_state_v) + recurrent_state
|
||||
else:
|
||||
last_state = recurrent_state
|
||||
past_key_values.update(last_state, self.layer_idx, q.shape[2])
|
||||
|
||||
o = rearrange(o, 'b h t d -> b t (h d)')
|
||||
if self.use_norm and not self.use_output_gate:
|
||||
o = swish(o)
|
||||
o = self.g_norm(o, self.o_proj.weight, self.o_proj.bias)
|
||||
elif self.use_output_gate and not self.use_norm:
|
||||
o = swiglu_linear(self.g_proj(hidden_states), o, self.o_proj.weight, self.o_proj.bias)
|
||||
elif self.use_output_gate and self.use_norm:
|
||||
o = self.g_norm(o, self.g_proj(hidden_states), self.o_proj.weight, self.o_proj.bias)
|
||||
else:
|
||||
o = self.o_proj(o)
|
||||
return o, None, past_key_values
|
||||
|
||||
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
|
||||
param = next(self.parameters())
|
||||
state = tuple()
|
||||
if self.use_short_conv:
|
||||
if self.share_conv_kernel:
|
||||
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
|
||||
else:
|
||||
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
|
||||
param.new_zeros(batch_size, self.key_dim, self.conv_size),
|
||||
param.new_zeros(batch_size, self.value_dim, self.conv_size))
|
||||
state += (param.new_zeros(batch_size, self.num_heads, self.head_k_dim, self.num_slots),
|
||||
param.new_zeros(batch_size, self.num_heads, self.num_slots, self.head_v_dim))
|
||||
return state
|
||||
|
||||
def state_size(self, sequence_length: int = 2048):
|
||||
return self.num_heads * self.key_dim * self.head_v_dim
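A minimal sketch of the layer above, including the grouped key/value heads it supports via `num_kv_heads` (illustrative only; the `fla` import path and a CUDA device are assumptions):

import torch
from fla.layers.gated_abc import GatedABCAttention

# 4 query heads sharing 2 key/value heads (num_kv_groups = 2)
layer = GatedABCAttention(hidden_size=1024, num_heads=4, num_kv_heads=2, layer_idx=0).cuda()
x = torch.randn(2, 128, 1024, device='cuda')
o, _, _ = layer(x)
assert o.shape == x.shape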
finetune/lora/v6/fla/layers/gla.py (vendored, new file, 268 lines)
@@ -0,0 +1,268 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
|
||||
from fla.modules.activations import ACT2FN
|
||||
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
|
||||
|
||||
|
||||
class GatedLinearAttention(nn.Module):
|
||||
r"""
|
||||
    The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635).  # noqa
|
||||
|
||||
Args:
|
||||
mode (str, Optional):
|
||||
Which GLA kernel to use.
|
||||
Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
|
||||
Default: `chunk`.
|
||||
hidden_size (int, Optional):
|
||||
The hidden size of the input. Default: 1024.
|
||||
expand_k (float, Optional):
|
||||
The expansion ratio for the key dim. Default: 0.5.
|
||||
expand_v (float, Optional):
|
||||
The expansion ratio for the value dim. Default: 1.0.
|
||||
num_heads (int, Optional):
|
||||
The number of heads. Default: 4.
|
||||
num_kv_heads (int, Optional):
|
||||
The number of key/value heads, used for MQA. Default: None.
|
||||
feature_map (str, Optional):
|
||||
Feature map function applied to queries/keys. Default: None.
|
||||
use_short_conv (bool, Optional):
|
||||
Whether to use short convolutions. Default: `False`.
|
||||
conv_size (int, Optional):
|
||||
The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
|
||||
conv_bias (bool, Optional):
|
||||
Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
|
||||
share_conv_kernel (bool, Optional):
|
||||
            Whether to apply the short convolution before the q/k/v projections; only takes effect when `use_short_conv` is `True`. Default: `True`.
|
||||
use_output_gate (bool, Optional):
|
||||
Whether to use output gate. Default: `True`.
|
||||
gate_fn (str, Optional):
|
||||
The activation function for the output gate. Default: `swish`.
|
||||
elementwise_affine (bool, Optional):
|
||||
If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
|
||||
norm_eps (float, Optional):
|
||||
The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
|
||||
gate_logit_normalizer (int, Optional):
|
||||
            The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
|
||||
gate_low_rank_dim (int, Optional):
|
||||
The low rank dim for the gate projection. Default: 16.
|
||||
clamp_min (float, Optional):
|
||||
The minimum value for the gate logits. Default: None.
|
||||
fuse_norm (bool, Optional):
|
||||
Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
|
||||
layer_idx (int, Optional):
|
||||
The index of the layer. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: str = 'chunk',
|
||||
hidden_size: int = 1024,
|
||||
expand_k: float = 0.5,
|
||||
expand_v: float = 1.0,
|
||||
num_heads: int = 4,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
feature_map: Optional[str] = None,
|
||||
use_short_conv: bool = False,
|
||||
conv_size: int = 4,
|
||||
conv_bias: bool = False,
|
||||
share_conv_kernel: bool = True,
|
||||
use_output_gate: bool = True,
|
||||
gate_fn: str = 'swish',
|
||||
elementwise_affine: Optional[bool] = True,
|
||||
norm_eps: float = 1e-5,
|
||||
gate_logit_normalizer: int = 16,
|
||||
gate_low_rank_dim: int = 16,
|
||||
clamp_min: Optional[float] = None,
|
||||
fuse_norm: bool = True,
|
||||
layer_idx: int = None,
|
||||
) -> GatedLinearAttention:
|
||||
super().__init__()
|
||||
|
||||
self.mode = mode
|
||||
self.hidden_size = hidden_size
|
||||
self.expand_k = expand_k
|
||||
self.expand_v = expand_v
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
|
||||
self.num_kv_groups = self.num_heads // self.num_kv_heads
|
||||
self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
|
||||
|
||||
self.use_short_conv = use_short_conv
|
||||
self.conv_size = conv_size
|
||||
self.conv_bias = conv_bias
|
||||
self.share_conv_kernel = share_conv_kernel
|
||||
self.use_output_gate = use_output_gate
|
||||
|
||||
self.key_dim = int(hidden_size * expand_k)
|
||||
self.value_dim = int(hidden_size * expand_v)
|
||||
self.key_dim_per_group = self.key_dim // self.num_kv_groups
|
||||
self.value_dim_per_group = self.value_dim // self.num_kv_groups
|
||||
self.clamp_min = clamp_min
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
        assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
|
||||
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
|
||||
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
|
||||
|
||||
self.head_qk_dim = self.key_dim // num_heads
|
||||
self.head_v_dim = self.value_dim // num_heads
|
||||
|
||||
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
|
||||
self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
|
||||
self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
|
||||
if self.use_output_gate:
|
||||
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
|
||||
|
||||
if use_short_conv:
|
||||
self.conv_size = conv_size
|
||||
if share_conv_kernel:
|
||||
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
|
||||
else:
|
||||
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
|
||||
self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
|
||||
self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
|
||||
|
||||
self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
|
||||
nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
|
||||
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
|
||||
|
||||
if gate_fn == 'swish' and fuse_norm and use_output_gate:
|
||||
self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
|
||||
self.fuse_norm_and_gate = True
|
||||
else:
|
||||
self.fuse_norm_and_gate = False
|
||||
self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
|
||||
self.gate_fn = ACT2FN[gate_fn]
|
||||
|
||||
self.gate_logit_normalizer = gate_logit_normalizer
|
||||
|
||||
self.apply(self._initialize_weights)
|
||||
|
||||
def _initialize_weights(self, module: nn.Module):
|
||||
if getattr(module, "_is_hf_initialized", False):
|
||||
return
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
module._is_hf_initialized = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
# launching the triton kernel for just one token will actually be slower
|
||||
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
|
||||
|
||||
last_state = past_key_values[self.layer_idx] if use_cache else None
|
||||
if self.use_short_conv:
|
||||
conv_state = last_state[0] if use_cache else None
|
||||
if self.share_conv_kernel:
|
||||
# conv state is updated inplace
|
||||
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
else:
|
||||
conv_state_q = last_state[0] if use_cache else None
|
||||
conv_state_k = last_state[1] if use_cache else None
|
||||
conv_state_v = last_state[2] if use_cache else None
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
q = self.q_conv1d(q, attention_mask, conv_state_q)
|
||||
k = self.k_conv1d(k, attention_mask, conv_state_k)
|
||||
v = self.v_conv1d(v, attention_mask, conv_state_v)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
gk = self.gk_proj(hidden_states)
|
||||
|
||||
if self.feature_map_fn is not None:
|
||||
q, k = map(self.feature_map_fn, (q, k))
|
||||
# dealing with left-padding
|
||||
if attention_mask is not None:
|
||||
v = v.mul_(attention_mask.unsqueeze(-1))
|
||||
q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads)
|
||||
if self.num_kv_groups > 1:
|
||||
k, v, gk = (repeat(x, 'b l (h d) -> b (h g) l d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v, gk))
|
||||
else:
|
||||
k, v, gk = (rearrange(x, 'b l (h d) -> b h l d', h=self.num_kv_heads) for x in (k, v, gk))
|
||||
gk = F.logsigmoid(gk) / self.gate_logit_normalizer
|
||||
|
||||
if self.clamp_min is not None:
|
||||
gk = torch.clamp_min(gk, self.clamp_min)
|
||||
|
||||
recurrent_state = last_state[-1] if use_cache else None
|
||||
if mode == 'fused_recurrent':
|
||||
o, recurrent_state = fused_recurrent_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
|
||||
elif mode == 'fused_chunk':
|
||||
o, recurrent_state = fused_chunk_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
|
||||
elif mode == 'chunk':
|
||||
o, recurrent_state = chunk_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache)
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported mode `{mode}`.")
|
||||
|
||||
if past_key_values is not None:
|
||||
if self.use_short_conv:
|
||||
if self.share_conv_kernel:
|
||||
last_state = (conv_state, recurrent_state)
|
||||
else:
|
||||
last_state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
|
||||
else:
|
||||
last_state = (recurrent_state,)
|
||||
past_key_values.update(last_state, self.layer_idx, q.shape[2])
|
||||
|
||||
o = rearrange(o, 'b h l d -> b l h d')
|
||||
if self.use_output_gate:
|
||||
g = self.g_proj(hidden_states)
|
||||
if self.fuse_norm_and_gate:
|
||||
g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
|
||||
o = self.g_norm_swish_gate(o, g)
|
||||
o = rearrange(o, 'b l h d -> b l (h d)')
|
||||
else:
|
||||
o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
|
||||
o = o * self.gate_fn(g)
|
||||
else:
|
||||
o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
|
||||
o = self.o_proj(o)
|
||||
|
||||
return o, None, past_key_values
|
||||
|
||||
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
|
||||
param = next(self.parameters())
|
||||
state = tuple()
|
||||
if self.use_short_conv:
|
||||
if self.share_conv_kernel:
|
||||
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
|
||||
else:
|
||||
state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
|
||||
param.new_zeros(batch_size, self.key_dim, self.conv_size),
|
||||
param.new_zeros(batch_size, self.value_dim, self.conv_size))
|
||||
state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
|
||||
return state
|
||||
|
||||
def state_size(self, **kwargs) -> int:
|
||||
state_size = self.key_dim * self.head_v_dim
|
||||
for module in self.children():
|
||||
if isinstance(module, ShortConvolution):
|
||||
state_size += module.state_size
|
||||
return state_size
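A minimal usage sketch matching the docstring defaults (illustrative; the `fla` import path and a CUDA device for the Triton GLA kernels are assumptions):

import torch
from fla.layers.gla import GatedLinearAttention

# expand_k=0.5 gives key_dim=512, i.e. 128-dim queries/keys per head with num_heads=4
layer = GatedLinearAttention(mode='chunk', hidden_size=1024, expand_k=0.5, num_heads=4, layer_idx=0).cuda()
x = torch.randn(2, 128, 1024, device='cuda')
o, _, _ = layer(x)
assert o.shape == x.shape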
finetune/lora/v6/fla/layers/hgrn.py (vendored, new file, 165 lines)
@@ -0,0 +1,165 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from fla.modules import FusedRMSNormSwishGate, ShortConvolution
|
||||
from fla.modules.activations import swiglu
|
||||
from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
|
||||
|
||||
|
||||
class HGRNAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: str = 'chunk',
|
||||
hidden_size: int = 1024,
|
||||
num_heads: Optional[int] = None,
|
||||
expand_ratio: Optional[int] = 1,
|
||||
use_short_conv: bool = False,
|
||||
conv_size: int = 4,
|
||||
conv_bias: bool = False,
|
||||
share_conv_kernel: bool = True,
|
||||
elementwise_affine: Optional[bool] = True,
|
||||
norm_eps: float = 1e-5,
|
||||
layer_idx: int = None
|
||||
) -> HGRNAttention:
|
||||
super().__init__()
|
||||
|
||||
self.mode = mode
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.expand_ratio = expand_ratio
|
||||
self.input_dim = int(hidden_size * expand_ratio)
|
||||
self.head_dim = self.input_dim // self.num_heads
|
||||
|
||||
self.use_short_conv = use_short_conv
|
||||
self.conv_size = conv_size
|
||||
self.conv_bias = conv_bias
|
||||
self.share_conv_kernel = share_conv_kernel
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
|
||||
assert self.hidden_size % num_heads == 0, f"hidden size must be divisible by num_heads of {num_heads}"
|
||||
|
||||
self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
|
||||
self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
|
||||
self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
|
||||
|
||||
if use_short_conv:
|
||||
self.conv_size = conv_size
|
||||
if share_conv_kernel:
|
||||
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
|
||||
else:
|
||||
self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
|
||||
self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
|
||||
self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
|
||||
|
||||
self.g_norm = FusedRMSNormSwishGate(self.input_dim, elementwise_affine, norm_eps)
|
||||
self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
|
||||
|
||||
self.apply(self._initialize_weights)
|
||||
|
||||
def _initialize_weights(self, module: nn.Module):
|
||||
if getattr(module, "_is_hf_initialized", False):
|
||||
return
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
module._is_hf_initialized = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
lower_bound: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
# launching the triton kernel for just one token will actually be slower
|
||||
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
|
||||
|
||||
last_state = past_key_values[self.layer_idx] if use_cache else None
|
||||
if self.use_short_conv:
|
||||
conv_state = last_state[0] if use_cache else None
|
||||
if self.share_conv_kernel:
|
||||
# conv state is updated inplace
|
||||
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
|
||||
i = self.i_proj(hidden_states)
|
||||
f = self.f_proj(hidden_states)
|
||||
else:
|
||||
                # cached states are stored as (conv_state_i, conv_state_f, recurrent_state) in the update below
                conv_state_i = last_state[0] if use_cache else None
                conv_state_f = last_state[1] if use_cache else None
|
||||
i = self.i_conv1d(self.i_proj(hidden_states), attention_mask, conv_state_i)
|
||||
f = self.f_conv1d(self.f_proj(hidden_states), attention_mask, conv_state_f)
|
||||
else:
|
||||
i = self.i_proj(hidden_states)
|
||||
f = self.f_proj(hidden_states)
|
||||
|
||||
# the lower bound for the first layer is zero
|
||||
if lower_bound is None or self.layer_idx == 0:
|
||||
i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
|
||||
else:
|
||||
g = lower_bound + (1 - lower_bound) * f.sigmoid()
|
||||
i, f = swiglu(i, 1 - g), g.log()
|
||||
|
||||
# dealing with left-padding
|
||||
if attention_mask is not None:
|
||||
i = i.mul_(attention_mask.unsqueeze(-1))
|
||||
i, f = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (i, f))
|
||||
|
||||
recurrent_state = last_state[-1] if use_cache else None
|
||||
if mode == 'chunk':
|
||||
o, recurrent_state = chunk_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache)
|
||||
elif mode == 'fused_recurrent':
|
||||
o, recurrent_state = fused_recurrent_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache)
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported mode `{mode}`.")
|
||||
|
||||
if past_key_values is not None:
|
||||
if self.use_short_conv:
|
||||
if self.share_conv_kernel:
|
||||
last_state = (conv_state, recurrent_state)
|
||||
else:
|
||||
last_state = (conv_state_i, conv_state_f, recurrent_state)
|
||||
else:
|
||||
last_state = (recurrent_state,)
|
||||
past_key_values.update(last_state, self.layer_idx, i.shape[2])
|
||||
|
||||
o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)'))
|
||||
o = self.o_proj(o)
|
||||
|
||||
return o, None, past_key_values
|
||||
|
||||
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
|
||||
param = next(self.parameters())
|
||||
state = tuple()
|
||||
if self.use_short_conv:
|
||||
if self.share_conv_kernel:
|
||||
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
|
||||
else:
|
||||
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),
|
||||
param.new_zeros(batch_size, self.hidden_size, self.conv_size),
|
||||
param.new_zeros(batch_size, self.hidden_size, self.conv_size))
|
||||
state += (param.new_zeros(batch_size, self.num_heads, self.head_dim),)
|
||||
return state
|
||||
|
||||
def state_size(self, **kwargs) -> int:
|
||||
state_size = self.hidden_size
|
||||
for module in self.children():
|
||||
if isinstance(module, ShortConvolution):
|
||||
state_size += module.state_size
|
||||
return state_size
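A minimal usage sketch for `HGRNAttention` (illustrative; the `fla` import path and a CUDA device are assumptions). Note that `num_heads` must be passed explicitly, since the constructor divides by it and the declared default of `None` is not handled:

import torch
from fla.layers.hgrn import HGRNAttention

layer = HGRNAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda()
x = torch.randn(2, 128, 1024, device='cuda')
o, _, _ = layer(x)
assert o.shape == x.shape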
finetune/lora/v6/fla/layers/hgrn2.py (vendored, new file, 186 lines)
@@ -0,0 +1,186 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904]
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from fla.modules import RMSNorm, ShortConvolution
|
||||
from fla.modules.activations import swish
|
||||
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
|
||||
|
||||
|
||||
class HGRN2Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: str = 'chunk',
|
||||
hidden_size: int = 1024,
|
||||
num_heads: Optional[int] = None,
|
||||
expand_ratio: Optional[int] = 128,
|
||||
use_short_conv: bool = False,
|
||||
conv_size: int = 4,
|
||||
conv_bias: bool = False,
|
||||
share_conv_kernel: bool = True,
|
||||
elementwise_affine: Optional[bool] = True,
|
||||
norm_eps: float = 1e-5,
|
||||
layer_idx: int = None
|
||||
) -> HGRN2Attention:
|
||||
super().__init__()
|
||||
|
||||
self.mode = mode
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
if expand_ratio is None and num_heads is not None:
|
||||
expand_ratio = hidden_size // num_heads
|
||||
elif expand_ratio is not None and num_heads is None:
|
||||
num_heads = hidden_size // expand_ratio
|
||||
else:
|
||||
raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
|
||||
self.num_heads = num_heads
|
||||
self.expand_ratio = expand_ratio
|
||||
|
||||
self.use_short_conv = use_short_conv
|
||||
self.conv_size = conv_size
|
||||
self.conv_bias = conv_bias
|
||||
self.share_conv_kernel = share_conv_kernel
|
||||
|
||||
self.forget_dim = int(self.num_heads * self.expand_ratio)
|
||||
self.input_dim = hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
        assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
|
||||
assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}"
|
||||
assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}"
|
||||
|
||||
self.head_f_dim = self.expand_ratio
|
||||
self.head_i_dim = self.hidden_size // num_heads
|
||||
|
||||
self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
|
||||
self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
|
||||
self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
|
||||
|
||||
if use_short_conv:
|
||||
self.conv_size = conv_size
|
||||
if share_conv_kernel:
|
||||
self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
|
||||
else:
|
||||
self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu')
|
||||
self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu')
|
||||
self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu')
|
||||
|
||||
self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, norm_eps)
|
||||
self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
|
||||
|
||||
self.apply(self._initialize_weights)
|
||||
|
||||
def _initialize_weights(self, module: nn.Module):
|
||||
if getattr(module, "_is_hf_initialized", False):
|
||||
return
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
module._is_hf_initialized = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
lower_bound: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
# launching the triton kernel for just one token will actually be slower
|
||||
mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode
|
||||
|
||||
last_state = past_key_values[self.layer_idx] if use_cache else None
|
||||
if self.use_short_conv:
|
||||
conv_state = last_state[0] if use_cache else None
|
||||
if self.share_conv_kernel:
|
||||
# conv state is updated inplace
|
||||
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
|
||||
q = self.q_proj(hidden_states)
|
||||
f = self.f_proj(hidden_states)
|
||||
i = self.i_proj(hidden_states)
|
||||
else:
|
||||
conv_state_q = last_state[0] if use_cache else None
|
||||
conv_state_f = last_state[1] if use_cache else None
|
||||
conv_state_i = last_state[2] if use_cache else None
|
||||
q = self.q_proj(hidden_states)
|
||||
f = self.f_proj(hidden_states)
|
||||
i = self.i_proj(hidden_states)
|
||||
q = self.q_conv1d(q, attention_mask, conv_state_q)
|
||||
f = self.f_conv1d(f, attention_mask, conv_state_f)
|
||||
i = self.i_conv1d(i, attention_mask, conv_state_i)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)
|
||||
f = self.f_proj(hidden_states)
|
||||
i = self.i_proj(hidden_states)
|
||||
|
||||
# dealing with left-padding
|
||||
if attention_mask is not None:
|
||||
i = i.mul_(attention_mask.unsqueeze(-1))
|
||||
|
||||
q = swish(q)
|
||||
# the lower bound for the first layer is zero
|
||||
if lower_bound is None or self.layer_idx == 0:
|
||||
k, g = 1 - f.sigmoid(), F.logsigmoid(f)
|
||||
else:
|
||||
g = lower_bound + (1 - lower_bound) * f.sigmoid()
|
||||
k, g = 1 - g, g.log()
|
||||
q, k, i, g = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, i, g))
|
||||
|
||||
recurrent_state = last_state[-1] if use_cache else None
|
||||
if mode == 'fused_recurrent':
|
||||
o, recurrent_state = fused_recurrent_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
|
||||
elif mode == 'fused_chunk':
|
||||
o, recurrent_state = fused_chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
|
||||
elif mode == 'chunk':
|
||||
o, recurrent_state = chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache)
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported mode `{mode}`.")
|
||||
|
||||
if past_key_values is not None:
|
||||
if self.use_short_conv:
|
||||
if self.share_conv_kernel:
|
||||
last_state = (conv_state, recurrent_state)
|
||||
else:
|
||||
last_state = (conv_state_q, conv_state_f, conv_state_i, recurrent_state)
|
||||
else:
|
||||
last_state = (recurrent_state,)
|
||||
past_key_values.update(last_state, self.layer_idx, q.shape[2])
|
||||
|
||||
o = self.g_norm(rearrange(o, 'b h l d -> b l (h d)'))
|
||||
o = self.o_proj(o)
|
||||
|
||||
return o, None, past_key_values
|
||||
|
||||
def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
|
||||
param = next(self.parameters())
|
||||
state = tuple()
|
||||
if self.use_short_conv:
|
||||
if self.share_conv_kernel:
|
||||
state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
|
||||
else:
|
||||
state += (param.new_zeros(batch_size, self.forget_dim, self.conv_size),
|
||||
param.new_zeros(batch_size, self.forget_dim, self.conv_size),
|
||||
param.new_zeros(batch_size, self.input_dim, self.conv_size))
|
||||
state += (param.new_zeros(batch_size, self.num_heads, self.head_f_dim, self.head_i_dim),)
|
||||
return state
|
||||
|
||||
def state_size(self, **kwargs) -> int:
|
||||
state_size = self.forget_dim * self.head_i_dim
|
||||
for module in self.children():
|
||||
if isinstance(module, ShortConvolution):
|
||||
state_size += module.state_size
|
||||
return state_size
|
||||
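The forward pass above reuses the GLA kernels by handing them (q, k, i, g), where k = 1 - gate and g is the gate's log. As a minimal pure-PyTorch sketch of the per-head recurrence those kernels are expected to realize (illustrative only; it ignores the kernels' chunking and internal scaling, and assumes q, k, i, g are already shaped `b h l d` as in the code above):

import torch

def hgrn2_recurrence_reference(q, k, i, g):
    # q, k, g: (batch, heads, length, d_k); i: (batch, heads, length, d_v)
    # g holds log forget gates, so exp(g) in (0, 1) decays the state per channel.
    b, h, l, d_k = q.shape
    d_v = i.shape[-1]
    S = q.new_zeros(b, h, d_k, d_v)          # matrix-valued state per head
    outputs = []
    for t in range(l):
        S = torch.exp(g[:, :, t]).unsqueeze(-1) * S \
            + k[:, :, t].unsqueeze(-1) * i[:, :, t].unsqueeze(-2)
        outputs.append(torch.einsum('bhk,bhkv->bhv', q[:, :, t], S))
    return torch.stack(outputs, dim=2)        # (batch, heads, length, d_v)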
156
finetune/lora/v6/fla/layers/linear_attn.py
vendored
Normal file
@@ -0,0 +1,156 @@
# -*- coding: utf-8 -*-

import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from fla.modules import RMSNorm
from fla.modules.feature_map import (DPFPFeatureMap, HadamardFeatureMap,
                                     HedgehogFeatureMap, T2RFeatureMap)
from fla.ops.linear_attn import (chunk_linear_attn, fused_chunk_linear_attn,
                                 fused_recurrent_linear_attn)


class LinearAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
        num_heads: int = 8,
        mode: str = 'chunk',
        feature_map: str = 'elementwise_product',
        tie_feature_map_qk: bool = False,
        output_norm: str = 'rmsnorm',
        norm_q: bool = False,
        norm_k: bool = False,
        # standard linear attention normalization
        do_feature_map_norm: bool = False,
        elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        **kwargs,
    ):
        super().__init__()
        assert feature_map in ['elu', 'relu', 'hedgehog', 't2r', 'dpfp',
                               'identity', 'elementwise_product'], f"Not supported feature map `{feature_map}`."

        assert output_norm in ['rmsnorm', 'identity'], f"Not supported output norm `{output_norm}`."

        self.hidden_size = hidden_size
        self.mode = mode
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.num_heads = num_heads

        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        if feature_map == 'hedgehog':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
            else:
                self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_qk_dim)
                self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)

        elif feature_map == 't2r':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
            else:
                self.feature_map_q = T2RFeatureMap(head_dim=self.head_qk_dim)
                self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)

        elif feature_map == 'elementwise_product':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
            else:
                self.feature_map_q = HadamardFeatureMap(head_dim=self.head_qk_dim)
                self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)

        elif feature_map == 'dpfp':
            self.feature_map_q = DPFPFeatureMap(head_dim=self.head_qk_dim)
            self.feature_map_k = DPFPFeatureMap(head_dim=self.head_qk_dim)

        elif feature_map == 'elu':
            def elu(x):
                return F.elu(x) + 1
            self.feature_map_q = elu
            self.feature_map_k = elu

        elif feature_map == 'relu':
            self.feature_map_q = nn.ReLU()
            self.feature_map_k = nn.ReLU()

        elif feature_map == 'identity':
            self.feature_map_q = nn.Identity()
            self.feature_map_k = nn.Identity()
        else:
            raise NotImplementedError

        self.do_feature_map_norm = do_feature_map_norm
        if output_norm == 'rmsnorm':
            self.norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
        elif output_norm == 'identity':
            self.norm = nn.Identity()
        else:
            raise NotImplementedError

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        self.norm_q = norm_q
        self.norm_k = norm_k

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(self, x):
        mode = self.mode
        q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        q = self.feature_map_q(q)
        k = self.feature_map_k(k)
        if self.norm_q:
            q = q / (q.sum(-1, keepdim=True) + 1e-4)
        if self.norm_k:
            k = k / (k.sum(-1, keepdim=True) + 1e-4)

        if mode == 'chunk':
            o = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
        elif mode == 'fused_chunk':
            o = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
        elif mode == 'fused_recurrent':
            o = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm)
        else:
            raise NotImplementedError
        o = self.norm(o)
        o = rearrange(o, 'b h n d -> b n (h d)')
        o = self.o_proj(o)
        return o


if __name__ == '__main__':
    import torch
    batch = 4
    seq_len = 1024
    hidden_size = 1024
    x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True)
    model = LinearAttention(hidden_size, feature_map='dpfp').to(torch.bfloat16).cuda()
    y = model(x)
    print(y.shape)
    y.sum().backward()
    print(x.grad.shape)
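For readers without a GPU, here is a short pure-PyTorch reference of causal linear attention, which is what the forward above delegates to `chunk_linear_attn` and friends after the feature maps are applied; the normalized variant (cf. `do_feature_map_norm`) divides by the running dot product of the mapped query with the cumulative sum of mapped keys. This is an illustrative sketch of the math, not the Triton kernels' exact numerics (their default scaling is omitted):

import torch

def causal_linear_attn_reference(q, k, v, normalize=False, eps=1e-6):
    # q, k: (b, h, n, d_k) already passed through the feature map; v: (b, h, n, d_v)
    kv = torch.einsum('bhnk,bhnv->bhnkv', k, v).cumsum(2)   # running sum of k_s v_s^T
    o = torch.einsum('bhnk,bhnkv->bhnv', q, kv)
    if normalize:
        z = torch.einsum('bhnk,bhnk->bhn', q, k.cumsum(2)).clamp(min=eps)
        o = o / z.unsqueeze(-1)
    return o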
271
finetune/lora/v6/fla/layers/multiscale_retention.py
vendored
Normal file
@@ -0,0 +1,271 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange, repeat
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache

from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
from fla.modules.rotary import RotaryEmbedding
from fla.ops.retention import (chunk_retention, fused_chunk_retention,
                               fused_recurrent_retention, parallel_retention)


class MultiScaleRetention(nn.Module):
    r"""
    The layer implementation for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf).  # noqa

    Args:
        mode (str, Optional):
            Which Retention kernel to use.
            Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`.
            Default: `fused_chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 1.0.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 2.0.
        num_heads (int, Optional):
            The number of heads. Default: 8.
        num_kv_heads (int, Optional):
            The number of key/value heads, used for MQA. Default: None.
        feature_map (str, Optional):
            Feature map function applied to queries/keys. Default: None.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `False`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        share_conv_kernel (bool, Optional):
            Whether to apply convolutions before the q/k/v mappings, only taking effect when `use_short_conv` is `True`. Default: `True`.
        use_output_gate (bool, Optional):
            Whether to use output gate. Default: `True`.
        gate_fn (str, Optional):
            The activation function for the output gate. Default: `swish`.
        elementwise_affine (bool, Optional):
            If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        fuse_norm (bool, Optional):
            Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
    """

    def __init__(
        self,
        mode: str = 'fused_chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 2.0,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        share_conv_kernel: bool = True,
        use_output_gate: bool = True,
        gate_fn: str = 'swish',
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        fuse_norm: bool = True,
        layer_idx: int = None,
        **kwargs
    ) -> MultiScaleRetention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.share_conv_kernel = share_conv_kernel
        self.use_output_gate = use_output_gate

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
        if self.use_output_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            if share_conv_kernel:
                self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu')
            else:
                self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
                self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
                self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        if gate_fn == 'swish' and fuse_norm and use_output_gate:
            self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
            self.gate_fn = ACT2FN[gate_fn]

        # TODO: fix this issue
        # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180
        # Ideally, we would want to support arbitrary d_head_qk
        assert self.head_qk_dim <= 256, "head_qk_dim must be less than or equal to 256"
        self.rotary = RotaryEmbedding(dim=self.head_qk_dim)

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode

        last_state = past_key_values[self.layer_idx] if use_cache else None
        if self.use_short_conv:
            conv_state = last_state[0] if use_cache else None
            if self.share_conv_kernel:
                # conv state is updated inplace
                hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)
                q = self.q_proj(hidden_states)
                k = self.k_proj(hidden_states)
                v = self.v_proj(hidden_states)
            else:
                conv_state_q = last_state[0] if use_cache else None
                conv_state_k = last_state[1] if use_cache else None
                conv_state_v = last_state[2] if use_cache else None
                q = self.q_proj(hidden_states)
                k = self.k_proj(hidden_states)
                v = self.v_proj(hidden_states)
                q = self.q_conv1d(q, attention_mask, conv_state_q)
                k = self.k_conv1d(k, attention_mask, conv_state_k)
                v = self.v_conv1d(v, attention_mask, conv_state_v)
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)

        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask.unsqueeze(-1))
        q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
        k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
        if self.feature_map_fn is not None:
            q, k = map(self.feature_map_fn, (q, k))

        seqlen_offset, max_seqlen = 0, None
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset
            if attention_mask is not None:
                # to exclude the offsets introduced by padding tokens
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                max_seqlen = q.shape[1] + max(seqlen_offset)
        q, k = self.rotary(q, k, seqlen_offset, max_seqlen)
        q = q.transpose(1, 2)
        if self.num_kv_groups > 1:
            k = repeat(k, 'b t h d -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups)
            v = repeat(v, 'b t (h d) -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups)
        else:
            k, v = rearrange(k, 'b t h d -> b h t d'), rearrange(v, 'b t (h d) -> b h t d', h=self.num_kv_heads)

        state = last_state[-1] if use_cache else None
        if mode == 'chunk':
            o, recurrent_state = chunk_retention(q, k, v, initial_state=state, output_final_state=use_cache)
        elif mode == 'fused_chunk':
            o, recurrent_state = fused_chunk_retention(q, k, v, initial_state=state, output_final_state=use_cache)
        elif mode == 'parallel':
            o, recurrent_state = parallel_retention(q, k, v, initial_state=state, output_final_state=use_cache)
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_retention(q, k, v, initial_state=state, output_final_state=use_cache)
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            if self.use_short_conv:
                if self.share_conv_kernel:
                    last_state = (conv_state, recurrent_state)
                else:
                    last_state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state)
            else:
                last_state = (recurrent_state,)
            past_key_values.update(last_state, self.layer_idx, q.shape[2])

        o = rearrange(o, 'b h l d -> b l h d')
        if self.use_output_gate:
            g = self.g_proj(hidden_states)
            if self.fuse_norm_and_gate:
                g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
                o = self.g_norm_swish_gate(o, g)
                o = rearrange(o, 'b l h d -> b l (h d)')
            else:
                o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
                o = o * self.gate_fn(g)
        else:
            o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values

    def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
        param = next(self.parameters())
        state = tuple()
        if self.use_short_conv:
            if self.share_conv_kernel:
                state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),)
            else:
                state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
                          param.new_zeros(batch_size, self.key_dim, self.conv_size),
                          param.new_zeros(batch_size, self.value_dim, self.conv_size))
        state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
        return state

    def state_size(self, **kwargs) -> int:
        state_size = self.key_dim * self.head_v_dim
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
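A minimal usage sketch for the layer above, in the spirit of the `__main__` blocks of the sibling files; it assumes a CUDA device with the Triton kernels from `fla.ops.retention` available and the vendored package importable as `fla` (i.e. `finetune/lora/v6` on `PYTHONPATH`):

import torch
from fla.layers.multiscale_retention import MultiScaleRetention

batch, seq_len, hidden_size = 2, 512, 1024
x = torch.randn(batch, seq_len, hidden_size, dtype=torch.bfloat16, device='cuda', requires_grad=True)
layer = MultiScaleRetention(hidden_size=hidden_size, num_heads=8, mode='fused_chunk').to(torch.bfloat16).cuda()
o, _, _ = layer(x)          # forward returns (output, attentions, past_key_values)
print(o.shape)              # torch.Size([2, 512, 1024])
o.sum().backward()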
137
finetune/lora/v6/fla/layers/rebased.py
vendored
Normal file
@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-

"""
https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange

from fla.modules.feature_map import RebasedFeatureMap
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
from fla.ops.rebased import parallel_rebased


class ReBasedLinearAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        l_max: int = 2048,
        feature_dim: int = 16,
        num_key_value_heads: int = 16,
        num_heads: int = 16,
        use_gamma: Optional[bool] = True,
        use_beta: Optional[bool] = True,
        normalize: Optional[bool] = True,
        causal: bool = True,
        eps: float = 1e-5,
        mode: str = "parallel",
        layer_idx: Optional[int] = None,
        **kwargs
    ) -> ReBasedLinearAttention:
        super().__init__()
        self.hidden_size = hidden_size
        self.l_max = l_max
        self.mode = mode
        assert self.mode in ["fused_chunk", "parallel", 'chunk']

        # linear attention
        self.feature_dim = feature_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_key_value_heads
        self.use_gamma = use_gamma
        self.use_beta = use_beta
        self.normalize = normalize
        self.causal = causal

        self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.dropout = nn.Identity()
        self.eps = eps

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        mode = self.mode
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v])
        q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
        if mode == "fused_chunk":
            o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1)
        elif mode == 'chunk':
            o = chunk_linear_attn(q, k, v, normalize=True, scale=1)
        elif mode == 'parallel':
            assert q.shape[-1] <= 128
            o = parallel_rebased(q, k, v, self.eps, True, True)
        o = rearrange(o, "b h l d -> b l (h d)")
        o = self.o_proj(o)
        o = self.dropout(o)
        return o

    # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
    def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
        """
        hidden_states (torch.Tensor): tensor of shape (b, l, d)
        returns (torch.Tensor): tensor of shape (b, l, d)
        """
        # hidden_states = hidden_states.transpose(1, 2)
        b, l, _ = hidden_states.size()
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2)
        k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
        v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Linear attention
        q, k = self.feature_map(q), self.feature_map(k)
        q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)

        # Compute attention
        if self.causal:
            y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
        else:
            y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
        y = rearrange(y, 'b h l d -> b l (h d)')
        y = self.o_proj(y.to(hidden_states.dtype))
        y = self.dropout(y)
        return y.to(hidden_states.dtype)


if __name__ == '__main__':
    batch = 4
    seq_len = 1024
    hidden_size = 1024
    dtype = torch.float32
    x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True)
    dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda()
    model = ReBasedLinearAttention(hidden_size=hidden_size, mode='parallel').to(dtype).cuda()

    y = model(x)
    y.backward(dy, retain_graph=True)
    x_grad, x.grad = x.grad, None
    print(model.mode)
    model.mode = 'fused_chunk'
    y2 = model(x)
    print(model.mode)
    y2.backward(dy)
    # assert y.allclose(y2, 0, 1e-4), breakpoint()
    # assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint()
    print("Pass")
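Because `forward` depends on the Triton kernels, the quadratic-time `forward_reference` above is the easiest way to sanity-check the layer without a GPU. A small sketch under that assumption (it further assumes the vendored `fla` package, including `RebasedFeatureMap`, imports cleanly on a CPU-only machine, which may not hold if `fla.ops` requires Triton at import time):

import torch
from fla.layers.rebased import ReBasedLinearAttention

x = torch.randn(2, 128, 256)
layer = ReBasedLinearAttention(hidden_size=256, feature_dim=16, num_heads=16, num_key_value_heads=16)
y = layer.forward_reference(x)   # pure-PyTorch path, no Triton kernels invoked
print(y.shape)                   # torch.Size([2, 128, 256])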
264
finetune/lora/v6/fla/layers/rwkv6.py
vendored
Normal file
@@ -0,0 +1,264 @@
# -*- coding: utf-8 -*-

# "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence" [https://arxiv.org/abs/2404.05892]

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache

from fla.modules import FusedLayerNormSwishGate, LayerNorm
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6


class RWKV6Attention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        gate_fn: str = 'swish',
        proj_low_rank_dim: int = 32,
        gate_low_rank_dim: int = 64,
        fuse_norm: bool = True,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None,
        **kwargs
    ) -> RWKV6Attention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.proj_low_rank_dim = proj_low_rank_dim
        self.gate_low_rank_dim = gate_low_rank_dim

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.x_proj = nn.Sequential(
            LerpLinear(hidden_size, proj_low_rank_dim * 5),
            nn.Tanh(),
            nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=True)
        )
        self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
        self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
        self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
        self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
        self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
        self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_qk_dim))

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        if gate_fn == 'swish' and fuse_norm:
            self.g_norm_swish_gate = FusedLayerNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = LayerNorm(self.head_v_dim, elementwise_affine, norm_eps)
            self.gate_fn = ACT2FN[gate_fn]

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        if isinstance(module, nn.Parameter):
            nn.init.xavier_uniform_(module, gain=2 ** -2.5)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        batch_size, seq_len, hidden_size = hidden_states.size()
        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode

        delta = self.time_shift(hidden_states) - hidden_states
        x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
        r, w, k, v, g = torch.einsum('b l n r, n r d -> b l n d',
                                     self.x_proj[1](x),
                                     self.x_proj[2].weight.view(5, -1, hidden_size)).unbind(-2)
        r = self.r_proj(hidden_states, r, delta)
        w = self.w_proj(hidden_states, w, delta)
        k = self.k_proj(hidden_states, k, delta)
        v = self.v_proj(hidden_states, v, delta)
        g = self.g_proj(hidden_states, g, delta)

        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask.unsqueeze(-1))
        r, w, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (r, w, k, v))
        w = -torch.exp(w)
        u = self.bonus

        last_state = past_key_values[self.layer_idx] if use_cache else None
        state = last_state[-1] if use_cache else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_rwkv6(r, k, v, w, u, initial_state=state, output_final_state=use_cache)
        elif mode == 'chunk':
            o, recurrent_state = chunk_rwkv6(r, k, v, w, u, initial_state=state, output_final_state=use_cache)
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update((recurrent_state,), self.layer_idx, r.shape[2])

        o = rearrange(o, 'b h l d -> b l h d')
        if self.fuse_norm_and_gate:
            g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
            o = self.g_norm_swish_gate(o, g)
            o = rearrange(o, 'b l h d -> b l (h d)')
        else:
            o = self.g_norm(o)
            o = rearrange(o, 'b l h d -> b l (h d)')
            o = o * self.gate_fn(g)
        o = self.o_proj(o)

        return o, None, past_key_values

    def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
        param = next(self.parameters())
        state = (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),)
        return state

    def state_size(self, **kwargs) -> int:
        state_size = self.key_dim * self.head_v_dim
        return state_size


class LoRA(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: int,
        bias: Optional[bool] = True
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim
        self.bias = bias

        self.lora = nn.Sequential(
            nn.Linear(input_dim, low_rank_dim, bias=False),
            nn.Tanh(),
            nn.Linear(low_rank_dim, output_dim, bias=bias)
        )

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}("
        s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
        if not self.bias:
            s += f", bias={self.bias}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lora(x)


class LerpLinear(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: Optional[int] = None
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        if low_rank_dim is None:
            self.linear = nn.Linear(input_dim, output_dim, bias=False)
        else:
            self.linear = LoRA(input_dim, output_dim, low_rank_dim)
        self.mu = nn.Parameter(torch.zeros(input_dim))

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
        if self.low_rank_dim is not None:
            s += f", low_rank_dim={self.low_rank_dim}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
        if delta is None:
            shifted = self.time_shift(x)
            if len(shifted.shape) == 2:
                shifted = shifted.unsqueeze(1)
            delta = shifted - x
        return self.linear(x + delta * self.mu)


class DDLerpLinear(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: Optional[int] = None
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        if low_rank_dim is None:
            self.linear = nn.Linear(input_dim, output_dim, bias=False)
        else:
            self.linear = LoRA(input_dim, output_dim, low_rank_dim)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
        if self.low_rank_dim is not None:
            s += f", low_rank_dim={self.low_rank_dim}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
        if delta is None:
            shifted = self.time_shift(x)
            if len(shifted.shape) == 2:
                shifted = shifted.unsqueeze(1)
            delta = shifted - x
        return self.linear(x + delta * mu)
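The `LerpLinear`/`DDLerpLinear` helpers above implement the RWKV-6 token shift: each projection sees a learned per-channel interpolation between the current token and the previous one, with `nn.ZeroPad2d((0, 0, 1, -1))` shifting the sequence right by one step and dropping the last position. A standalone illustration of just that interpolation, independent of the Triton kernels (the value 0.25 stands in for a learned `mu`):

import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))      # pad one step at the front, crop one at the back
x = torch.arange(1., 7.).view(1, 3, 2)        # (batch=1, length=3, dim=2)
shifted = time_shift(x)                        # row t now holds x[t-1]; row 0 is zeros
mu = torch.full((2,), 0.25)                    # per-channel interpolation weights, as in LerpLinear
lerped = x + (shifted - x) * mu                # == (1 - mu) * x_t + mu * x_{t-1}
print(shifted)
print(lerped)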
143
finetune/lora/v6/fla/layers/simple_gla.py
vendored
Normal file
@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN

from fla.modules import FusedRMSNormSwishGate, RMSNorm
from fla.ops.simple_gla import chunk_simple_gla


class SimpleGatedLinearAttention(nn.Module):
    r"""
    The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635).  # noqa
    This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.

    Args:
        mode (str, Optional):
            Which GLA kernel to use.
            Currently available: `chunk`.
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 1.0.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 2.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        gate_fn (str, Optional):
            The activation function for the output gate. Default: `swish`.
        elementwise_affine (bool, Optional):
            If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        gate_logit_normalizer (int, Optional):
            The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
        fuse_norm (bool, Optional):
            Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 2.0,
        num_heads: int = 4,
        gate_fn: str = 'swish',
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 16,
        fuse_norm: bool = True,
        **kwargs
    ) -> SimpleGatedLinearAttention:
        super().__init__()
        self.hidden_size = hidden_size

        self.mode = mode
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        assert mode in ['chunk'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
        self.num_heads = num_heads
        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.gate_fn = ACT2FN[gate_fn]

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        self.gk_proj = nn.Linear(hidden_size, self.num_heads)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        if gate_fn == 'swish' and fuse_norm:
            self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)

        self.gate_logit_normalizer = gate_logit_normalizer

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(self, x):
        mode = self.mode
        q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
        gk = rearrange(self.gk_proj(x), 'b n h -> b h n')
        gk = (F.logsigmoid(gk) / self.gate_logit_normalizer)

        if mode == 'chunk':
            o = chunk_simple_gla(q, k, v, gk)
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        o = rearrange(o, 'b h l d -> b l h d')
        g = self.g_proj(x)

        if self.fuse_norm_and_gate:
            g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
            o = self.g_norm_swish_gate(o, g)
            o = rearrange(o, 'b l h d -> b l (h d)')
        else:
            o = self.g_norm(o)
            o = rearrange(o, 'b l h d -> b l (h d)')
            o = o * self.gate_fn(g)
        o = self.o_proj(o)
        return o


if __name__ == '__main__':
    batch = 4
    seq_len = 1024

    hidden_size = 2048
    x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True)
    model = SimpleGatedLinearAttention(hidden_size=hidden_size, mode='chunk').to(torch.bfloat16).cuda()
    y = model(x)
    print(y.shape)
    y.sum().backward()
    print(x.grad.shape)
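The gate `gk` above is a single scalar per head and timestep (bounded through `logsigmoid` and divided by `gate_logit_normalizer`), so the recurrence behind `chunk_simple_gla` decays the whole per-head state by one factor rather than per channel. A rough pure-PyTorch reference of that recurrence, for intuition only (the Triton kernel's scaling and chunking details are omitted):

import torch

def simple_gla_recurrence_reference(q, k, v, gk):
    # q, k: (b, h, n, d_k); v: (b, h, n, d_v); gk: (b, h, n) log decay per head and step
    b, h, n, d_k = q.shape
    S = q.new_zeros(b, h, d_k, v.shape[-1])   # one matrix-valued state per head
    outs = []
    for t in range(n):
        S = torch.exp(gk[:, :, t])[..., None, None] * S \
            + k[:, :, t].unsqueeze(-1) * v[:, :, t].unsqueeze(-2)
        outs.append(torch.einsum('bhk,bhkv->bhv', q[:, :, t], S))
    return torch.stack(outs, dim=2)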