# -*- coding: utf-8 -*-

# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py

import math
import warnings
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from fla.modules.activations import ACT2FN
from fla.utils import checkpoint

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_update = None


def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
    seqlen = u.shape[-1]
    # Zero-pad to 2 * seqlen so the circular FFT convolution matches a linear (causal) one
    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    if k_rev is not None:
        k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
        k_f = k_f + k_rev_f.conj()
    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)

    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]

    out = y + u
    if gelu:
        out = F.gelu(out)
    if dropout_mask is not None:
        return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
    else:
        return out.to(dtype=u.dtype)


@checkpoint
def proj_then_conv1d(
    x: torch.Tensor,
    proj_weight: torch.Tensor,
    conv1d_weight: torch.Tensor,
    conv1d_bias: Optional[torch.Tensor] = None,
    cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # We do matmul and transpose BLH -> HBL at the same time
    x = rearrange(proj_weight @ rearrange(x, "b l d -> d (b l)"), "d (b l) -> b d l", l=x.shape[-2])

    if causal_conv1d_fn is None:
        raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
    if cache is None:
        x = causal_conv1d_fn(
            x=x,
            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
            bias=conv1d_bias,
            activation="silu",
        ).transpose(1, 2)
    else:
        assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
        x = x.squeeze(-1)
        x = causal_conv1d_update(
            x=x,
            conv_state=cache,
            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
            bias=conv1d_bias,
            activation="silu",
        )
    return x
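

# The helper below is an added, illustrative usage sketch and is not part of the
# upstream zoology/fla module; all shapes are assumptions chosen only to show the
# channel-first layout that `fft_conv` expects.
def _fft_conv_usage_example() -> torch.Tensor:
    """Minimal sketch: apply one learned per-channel filter with `fft_conv`."""
    batch_size, hidden_size, seq_len = 2, 64, 128
    u = torch.randn(batch_size, hidden_size, seq_len)  # inputs are (b, d, l), i.e. channel first
    k = torch.randn(hidden_size, seq_len)              # one length-`seq_len` filter per channel
    # No dropout mask and no GELU, mirroring how the long-convolution layers below call it
    return fft_conv(u, k, dropout_mask=None, gelu=False)
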
class ShortConvolution(nn.Conv1d):
    """
    Simple wrapper around `nn.Conv1d` that accepts dimension last.
    """

    def __init__(
        self,
        hidden_size: int,
        kernel_size: int,
        bias: bool = False,
        activation: Optional[str] = 'silu',
        use_causal_conv: Optional[bool] = True
    ):
        super().__init__(in_channels=hidden_size,
                         out_channels=hidden_size,
                         kernel_size=kernel_size,
                         groups=hidden_size,
                         bias=bias,
                         padding=kernel_size - 1)

        self.hidden_size = hidden_size
        self.activation = None
        if activation is not None:
            assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
            self.activation = activation

        if use_causal_conv:
            if causal_conv1d_fn is None:
                warnings.warn("Please install `causal-conv1d` to use causal convolutions, setting `use_causal_conv` to False.")
                use_causal_conv = False
        self.use_causal_conv = use_causal_conv

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        if self.activation is not None:
            s += ', activation={activation}'
        if not self.use_causal_conv:
            s += ', use_causal_conv={use_causal_conv}'
        return s.format(**self.__dict__)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        cache: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x (`torch.Tensor`):
                Tensor of shape `[batch_size, seq_len, hidden_size]`.
            mask (`Optional[torch.Tensor]`):
                Attention mask dealing with padded positions.
            cache (`Optional[torch.Tensor]`):
                Previous cache tensor of shape `[batch_size, hidden_size, kernel_size]`.
        Returns:
            Tensor of shape `[batch_size, seq_len, hidden_size]`.
            The `cache` (if provided) is updated inplace.
        """

        if mask is not None:
            x = x.mul_(mask.unsqueeze(-1))
        if cache is not None and x.shape[1] == 1:
            return self.step(x, cache)
        x = rearrange(x, "b l d -> b d l")
        # Update state (B D W)
        if cache is not None:
            cache.copy_(F.pad(x, (self.kernel_size[0] - x.shape[-1], 0)))
        if self.use_causal_conv:
            x = causal_conv1d_fn(
                x=x,
                weight=rearrange(self.weight, "d 1 w -> d w"),
                bias=self.bias,
                activation=self.activation,
            )
        else:
            x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
            if self.activation is not None:
                x = ACT2FN[self.activation](x)
        return rearrange(x, "b d l -> b l d")

    def step(
        self,
        x: torch.Tensor,
        cache: torch.Tensor
    ):
        assert x.shape[1] == 1, "Only support decoding with 1 token at a time for now"

        x = x.squeeze(1)
        if self.use_causal_conv:
            x = causal_conv1d_update(
                x=x,
                conv_state=cache,
                weight=rearrange(self.weight, "d 1 w -> d w"),
                bias=self.bias,
                activation=self.activation,
            )
        else:
            dtype = x.dtype
            cache.copy_(torch.roll(cache, shifts=-1, dims=-1))
            cache[:, :, -1] = x
            x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
            if self.bias is not None:
                x = x + self.bias
            if self.activation is not None:
                x = ACT2FN[self.activation](x).to(dtype=dtype)
        return x.unsqueeze(1)

    @property
    def state_size(self) -> int:
        return self.hidden_size * self.kernel_size[0]
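

# Added, illustrative usage sketch for `ShortConvolution`; it is not part of the upstream
# module. It uses the pure-PyTorch fallback (`use_causal_conv=False`) so it does not
# require the optional `causal-conv1d` CUDA kernels, and all sizes are assumptions.
def _short_convolution_usage_example():
    batch_size, seq_len, hidden_size, kernel_size = 2, 16, 64, 4
    conv = ShortConvolution(hidden_size, kernel_size, activation='silu', use_causal_conv=False)
    # Prefill: inputs are dimension-last, (b, l, d); the cache (b, d, w) is updated in place
    # with the last `kernel_size` positions.
    x = torch.randn(batch_size, seq_len, hidden_size)
    cache = x.new_zeros(batch_size, hidden_size, kernel_size)
    y = conv(x, cache=cache)              # (b, l, d)
    # Decoding: with a cache and a single new token, `forward` dispatches to `step`.
    x_t = torch.randn(batch_size, 1, hidden_size)
    y_t = conv(x_t, cache=cache)          # (b, 1, d)
    return y, y_t
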
""" super().__init__() self.hidden_size = hidden_size self.filter = nn.Parameter(torch.randn(self.hidden_size, l_max), requires_grad=True) def forward(self, x: torch.Tensor, *args, **kwargs): """ Applies the LongConvolution operation on the input tensor. Args: x: (b, l, d) tensor Returns: y: (b, l, d) tensor """ x = x.transpose(1, 2) y = fft_conv(x, self.filter, dropout_mask=None, gelu=False) y = y.transpose(1, 2) return y.to(dtype=x.dtype) class PositionalEmbedding(nn.Module): def __init__(self, emb_dim: int, seq_len: int, **kwargs): """Complex exponential positional embeddings for implicit long convolution filters.""" super().__init__() self.seq_len = seq_len # The time embedding fed to the filteres is normalized so that t_f = 1 t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 if emb_dim > 1: bands = (emb_dim - 1) // 2 # To compute the right embeddings we use the "proper" linspace t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 f = torch.linspace(1e-4, bands - 1, bands)[None, None] z = torch.exp(-1j * f * w) z = torch.cat([t, z.real, z.imag], dim=-1) self.z = nn.Parameter(z, requires_grad=False) def forward(self, L): return self.z[:, :L] class ImplicitLongConvolution(nn.Module): """ Long convolution with implicit filter parameterized by an MLP. Args: hidden_size (int): The number of expected features in the input and output. l_max (int): The maximum sequence length. d_emb (Optional[int]): The dimension of the positional embeddings. Must be odd and greater or equal to 3 (time, sine and cosine). Defaults to 3. d_hidden (Optional[int]): The number of features in the hidden layer of the MLP. Defaults to 16. Attributes: pos_emb (`PositionalEmbedding`): The positional embedding layer. mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter. """ def __init__( self, hidden_size: int, l_max: int, d_emb: int = 3, d_hidden: int = 16, **kwargs, ): """ Long convolution with implicit filter parameterized by an MLP. """ super().__init__() self.hidden_size = hidden_size self.d_emb = d_emb assert ( d_emb % 2 != 0 and d_emb >= 3 ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)" self.pos_emb = PositionalEmbedding(d_emb, l_max) # final linear layer self.mlp = nn.Sequential( nn.Linear(d_emb, d_hidden), torch.nn.ReLU(), nn.Linear(d_hidden, hidden_size), ) def filter(self, seq_len: int, *args, **kwargs): k = self.mlp(self.pos_emb(seq_len)) return k.transpose(1, 2) def forward(self, x: torch.Tensor, *args, **kwargs): """ Args: x: (b, l, d) tensor Returns: y: (b, l, d) tensor """ x = x.transpose(1, 2) k = self.filter(x.shape[-1]) y = fft_conv(x, k, dropout_mask=None, gelu=False) y = y.transpose(1, 2) return y.to(dtype=x.dtype)