# -*- coding: utf-8 -*-

# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py

import math
import warnings
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from fla.modules.activations import ACT2FN
from fla.utils import checkpoint

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_update = None


def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
    # Causal (linear) convolution of `u` with the filter `k` via FFT, zero-padded to
    # 2 * seqlen to avoid circular wrap-around, followed by a residual connection and
    # optional GELU/dropout.
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    if k_rev is not None:
        k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
        k_f = k_f + k_rev_f.conj()
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]

    out = y + u
    if gelu:
        out = F.gelu(out)
    if dropout_mask is not None:
        return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
    else:
        return out.to(dtype=u.dtype)


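# NOTE: The following check is an illustrative sketch added for exposition; it is not part of
# the original zoology/fla file, and the function name is hypothetical. It assumes `u` is a
# (batch, channels, seqlen) activation and `k` a per-channel filter of the same length, and
# verifies that the FFT path above matches a direct causal depthwise convolution (plus the
# residual that `fft_conv` adds).
def _example_fft_conv_matches_direct_conv():
    b, d, l = 2, 4, 8
    u = torch.randn(b, d, l)
    k = torch.randn(d, l)
    y_fft = fft_conv(u, k, dropout_mask=None, gelu=False)
    # F.conv1d performs cross-correlation, so flip the kernel and left-pad to obtain a causal
    # convolution; `groups=d` makes it depthwise, i.e. one filter per channel.
    w = k.flip(-1).unsqueeze(1)                                   # (d, 1, l)
    y_direct = F.conv1d(F.pad(u, (l - 1, 0)), w, groups=d) + u
    assert torch.allclose(y_fft, y_direct, atol=1e-4)

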
@checkpoint
def proj_then_conv1d(
    x: torch.Tensor,
    proj_weight: torch.Tensor,
    conv1d_weight: torch.Tensor,
    conv1d_bias: Optional[torch.Tensor] = None,
    cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # We do matmul and transpose BLH -> HBL at the same time
    x = rearrange(proj_weight @ rearrange(x, "b l d -> d (b l)"), "d (b l) -> b d l", l=x.shape[-2])

    if causal_conv1d_fn is None:
        raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
    if cache is None:
        x = causal_conv1d_fn(
            x=x,
            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
            bias=conv1d_bias,
            activation="silu",
        ).transpose(1, 2)
    else:
        assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
        x = x.squeeze(-1)
        x = causal_conv1d_update(
            x=x,
            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
            bias=conv1d_bias,
            conv_state=cache,
            activation="silu",
        )
    return x


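# NOTE: Illustrative reference added for exposition; it is not part of the original file, and
# its name and assumed weight shapes are hypothetical. It spells out in plain PyTorch what
# `proj_then_conv1d` computes on the prefill path (`cache is None`): an input projection
# followed by a depthwise causal conv1d with SiLU. Unlike the fused `causal_conv1d` kernel,
# it does not require CUDA.
def _reference_proj_then_conv1d(
    x: torch.Tensor,              # [batch_size, seq_len, hidden_size]
    proj_weight: torch.Tensor,    # [proj_size, hidden_size]
    conv1d_weight: torch.Tensor,  # [proj_size, 1, kernel_size]
    conv1d_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    seq_len = x.shape[1]
    x = rearrange(F.linear(x, proj_weight), "b l d -> b d l")
    # Depthwise causal convolution: pad by kernel_size - 1, then keep the first seq_len outputs.
    x = F.conv1d(
        x,
        conv1d_weight,
        bias=conv1d_bias,
        padding=conv1d_weight.shape[-1] - 1,
        groups=conv1d_weight.shape[0],
    )[..., :seq_len]
    return F.silu(rearrange(x, "b d l -> b l d"))

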
class ShortConvolution(nn.Conv1d):
    """
    Simple wrapper around `nn.Conv1d` that accepts inputs with the hidden (channel)
    dimension last, i.e. of shape `[batch_size, seq_len, hidden_size]`.
    """

    def __init__(
        self,
        hidden_size: int,
        kernel_size: int,
        bias: bool = False,
        activation: Optional[str] = 'silu',
        use_causal_conv: Optional[bool] = True
    ):
        super().__init__(in_channels=hidden_size,
                         out_channels=hidden_size,
                         kernel_size=kernel_size,
                         groups=hidden_size,
                         bias=bias,
                         padding=kernel_size - 1)

        self.hidden_size = hidden_size
        self.activation = None
        if activation is not None:
            assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
            self.activation = activation

        if use_causal_conv:
            if causal_conv1d_fn is None:
                warnings.warn("Please install `causal-conv1d` to use causal convolutions, setting `use_causal_conv` to False.")
                use_causal_conv = False
        self.use_causal_conv = use_causal_conv

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        if self.activation is not None:
            s += ', activation={activation}'
        if not self.use_causal_conv:
            s += ', use_causal_conv={use_causal_conv}'
        return s.format(**self.__dict__)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        cache: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x (`torch.Tensor`):
                Tensor of shape `[batch_size, seq_len, hidden_size]`
            mask (`Optional[torch.Tensor]`):
                Attention mask dealing with padded positions.
            cache (`Optional[torch.Tensor]`):
                Previous cache tensor of shape `[batch_size, hidden_size, kernel_size]`.
        Returns:
            Tensor of shape `[batch_size, seq_len, hidden_size]`. The `cache` (if provided) is updated inplace.
        """

        if mask is not None:
            x = x.mul_(mask.unsqueeze(-1))
        if cache is not None and x.shape[1] == 1:
            return self.step(x, cache)
        x = rearrange(x, "b l d -> b d l")
        # Update state (B D W)
        if cache is not None:
            cache.copy_(F.pad(x, (self.kernel_size[0] - x.shape[-1], 0)))
        if self.use_causal_conv:
            x = causal_conv1d_fn(
                x=x,
                weight=rearrange(self.weight, "d 1 w -> d w"),
                bias=self.bias,
                activation=self.activation,
            )
        else:
            x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
            if self.activation is not None:
                x = ACT2FN[self.activation](x)
        return rearrange(x, "b d l -> b l d")

    def step(
        self,
        x: torch.Tensor,
        cache: torch.Tensor
    ):
        assert x.shape[1] == 1, "Only support decoding with 1 token at a time for now"

        x = x.squeeze(1)
        if self.use_causal_conv:
            x = causal_conv1d_update(
                x=x,
                conv_state=cache,
                weight=rearrange(self.weight, "d 1 w -> d w"),
                bias=self.bias,
                activation=self.activation,
            )
        else:
            dtype = x.dtype
            # Pure-PyTorch fallback: shift the rolling cache left by one position, append the
            # new token, then take a dot product with the depthwise kernel along the window dim.
            cache.copy_(torch.roll(cache, shifts=-1, dims=-1))
            cache[:, :, -1] = x
            x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
            if self.bias is not None:
                x = x + self.bias
            if self.activation is not None:
                x = ACT2FN[self.activation](x).to(dtype=dtype)
        return x.unsqueeze(1)

    @property
    def state_size(self) -> int:
        # `nn.Conv1d` stores `kernel_size` as a tuple, so index it to return an int.
        return self.hidden_size * self.kernel_size[0]


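# NOTE: Illustrative usage sketch added for exposition; it is not part of the original file,
# and the function name is hypothetical. It uses the pure-PyTorch fallback
# (`use_causal_conv=False`) so that no CUDA kernel is required, and checks that cached
# single-token decoding reproduces the full-sequence forward pass.
def _example_short_convolution_decoding():
    torch.manual_seed(0)
    conv = ShortConvolution(hidden_size=8, kernel_size=4, activation='silu', use_causal_conv=False)
    x = torch.randn(2, 16, 8)                     # [batch_size, seq_len, hidden_size]
    y_full = conv(x)
    cache = x.new_zeros(2, 8, 4)                  # [batch_size, hidden_size, kernel_size]
    y_step = torch.cat([conv(x[:, t:t + 1], cache=cache) for t in range(16)], dim=1)
    assert torch.allclose(y_full, y_step, atol=1e-4)

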
class LongConvolution(nn.Module):
    """
    LongConvolution applies a convolution operation on the input tensor using a fixed
    filter of length l_max.
    The filter is learned during training and is applied using FFT convolution.
    Args:
        hidden_size (int): The number of expected features in the input and output.
        l_max (int): The maximum sequence length.
    Returns:
        y: (b, l, d) tensor
    """

    def __init__(
        self,
        hidden_size: int,
        l_max: int,
        **kwargs,
    ):
        """
        Initializes the LongConvolution module.
        Args:
            hidden_size (int): The number of expected features in the input and output.
            l_max (int): The maximum sequence length.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.filter = nn.Parameter(torch.randn(self.hidden_size, l_max), requires_grad=True)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """
        Applies the LongConvolution operation on the input tensor.
        Args:
            x: (b, l, d) tensor
        Returns:
            y: (b, l, d) tensor
        """
        x = x.transpose(1, 2)
        y = fft_conv(x, self.filter, dropout_mask=None, gelu=False)
        y = y.transpose(1, 2)
        return y.to(dtype=x.dtype)


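# NOTE: Illustrative usage sketch added for exposition; it is not part of the original file,
# and the function name is hypothetical. The example feeds a sequence of exactly `l_max`
# tokens through the explicitly parameterized long convolution.
def _example_long_convolution():
    conv = LongConvolution(hidden_size=8, l_max=32)
    x = torch.randn(2, 32, 8)      # [batch_size, seq_len, hidden_size]
    y = conv(x)                    # causal FFT convolution plus residual, same shape as `x`
    assert y.shape == x.shape

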
class PositionalEmbedding(nn.Module):
    def __init__(self, emb_dim: int, seq_len: int, **kwargs):
        """Complex exponential positional embeddings for implicit long convolution filters."""
        super().__init__()

        self.seq_len = seq_len
        # The time embedding fed to the filters is normalized so that t_f = 1
        t = torch.linspace(0, 1, self.seq_len)[None, :, None]  # 1, L, 1

        if emb_dim > 1:
            bands = (emb_dim - 1) // 2
        # To compute the right embeddings we use the "proper" linspace
        t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
        w = 2 * math.pi * t_rescaled / seq_len  # 1, L, 1

        f = torch.linspace(1e-4, bands - 1, bands)[None, None]
        z = torch.exp(-1j * f * w)
        z = torch.cat([t, z.real, z.imag], dim=-1)
        self.z = nn.Parameter(z, requires_grad=False)

    def forward(self, L):
        return self.z[:, :L]


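# NOTE: Illustrative sketch added for exposition; it is not part of the original file, and the
# function name is hypothetical. With `emb_dim=3` the embedding concatenates the normalized
# time channel with the real and imaginary parts of a single complex-exponential band.
def _example_positional_embedding_shape():
    pos = PositionalEmbedding(emb_dim=3, seq_len=64)
    z = pos(16)                   # slice of the precomputed buffer for the first 16 positions
    assert z.shape == (1, 16, 3)

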
class ImplicitLongConvolution(nn.Module):
    """
    Long convolution with implicit filter parameterized by an MLP.

    Args:
        hidden_size (int):
            The number of expected features in the input and output.
        l_max (int):
            The maximum sequence length.
        d_emb (Optional[int]):
            The dimension of the positional embeddings. Must be odd and greater or equal to 3 (time, sine and cosine).
            Defaults to 3.
        d_hidden (Optional[int]):
            The number of features in the hidden layer of the MLP. Defaults to 16.

    Attributes:
        pos_emb (`PositionalEmbedding`): The positional embedding layer.
        mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter.

    """

    def __init__(
        self,
        hidden_size: int,
        l_max: int,
        d_emb: int = 3,
        d_hidden: int = 16,
        **kwargs,
    ):
        """
        Initializes the long convolution with an implicit filter parameterized by an MLP.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.d_emb = d_emb

        assert (
            d_emb % 2 != 0 and d_emb >= 3
        ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)"
        self.pos_emb = PositionalEmbedding(d_emb, l_max)

        # MLP mapping positional features to a per-channel filter value
        self.mlp = nn.Sequential(
            nn.Linear(d_emb, d_hidden),
            torch.nn.ReLU(),
            nn.Linear(d_hidden, hidden_size),
        )

    def filter(self, seq_len: int, *args, **kwargs):
        k = self.mlp(self.pos_emb(seq_len))

        return k.transpose(1, 2)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """
        Args:
            x: (b, l, d) tensor
        Returns:
            y: (b, l, d) tensor
        """
        x = x.transpose(1, 2)
        k = self.filter(x.shape[-1])
        y = fft_conv(x, k, dropout_mask=None, gelu=False)

        y = y.transpose(1, 2)
        return y.to(dtype=x.dtype)


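# NOTE: Illustrative usage sketch added for exposition; it is not part of the original file,
# and the function name is hypothetical. The filter is generated on the fly by the MLP from
# positional features, so the parameter count is set by `d_emb`/`d_hidden` rather than `l_max`.
def _example_implicit_long_convolution():
    conv = ImplicitLongConvolution(hidden_size=8, l_max=64, d_emb=3, d_hidden=16)
    x = torch.randn(2, 64, 8)     # [batch_size, seq_len, hidden_size]
    k = conv.filter(64)           # implicitly parameterized filter of shape [1, hidden_size, seq_len]
    y = conv(x)                   # causal FFT convolution plus residual, same shape as `x`
    assert k.shape == (1, 8, 64) and y.shape == x.shape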