91 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
		
			Vendored
		
	
	
	
			
		
		
	
	
			91 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
		
			Vendored
		
	
	
	
| # -*- coding: utf-8 -*-
 | |
| 
 | |
| from typing import Optional
 | |
| 
 | |
| import torch
 | |
| 
 | |
| 
 | |
| def naive_recurrent_abc(
 | |
|     q: torch.Tensor,
 | |
|     k: torch.Tensor,
 | |
|     v: torch.Tensor,
 | |
|     s: torch.Tensor,
 | |
|     g: Optional[torch.Tensor] = None,
 | |
|     scale: Optional[int] = None,
 | |
|     initial_state: Optional[torch.Tensor] = None,
 | |
|     output_final_state: Optional[bool] = False
 | |
| ) -> torch.Tensor:
 | |
|     dtype = q.dtype
 | |
| 
 | |
|     # [batch_size, n_heads, seq_len, n_slots]
 | |
|     if g is None:
 | |
|         z = s.float().logcumsumexp(2)
 | |
|         g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
 | |
|         s = torch.exp(s - z)
 | |
|     q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
 | |
|     B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
 | |
| 
 | |
|     hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
 | |
|     ok = torch.zeros_like(s)
 | |
| 
 | |
|     if scale is None:
 | |
|         scale = q.shape[-1] ** -0.5
 | |
| 
 | |
|     final_state = None
 | |
|     if initial_state is not None:
 | |
|         hk += initial_state[0]
 | |
| 
 | |
|     for i in range(T):
 | |
|         q_i = q[:, :, i] * scale
 | |
|         k_i = k[:, :, i]
 | |
|         v_i = s[:, :, i]
 | |
|         g_i = g[:, :, i].exp()
 | |
|         hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
 | |
|         ok[:, :, i] = (q_i[..., None] * hk).sum(-2)
 | |
| 
 | |
|     qv = ok.softmax(-1)
 | |
|     hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
 | |
|     ov = torch.zeros_like(v)
 | |
|     if initial_state is not None:
 | |
|         hv += initial_state[1]
 | |
| 
 | |
|     for i in range(T):
 | |
|         q_i = qv[:, :, i]
 | |
|         k_i = s[:, :, i]
 | |
|         v_i = v[:, :, i]
 | |
|         g_i = g[:, :, i].exp()
 | |
|         hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
 | |
|         ov[:, :, i] = (q_i[..., None] * hv).sum(-2)
 | |
| 
 | |
|     if output_final_state:
 | |
|         final_state = (hk, hv)
 | |
|     return ov.to(dtype), final_state
 | |
| 
 | |
| 
 | |
| def naive_cumsum_abc(
 | |
|     q: torch.Tensor,
 | |
|     k: torch.Tensor,
 | |
|     v: torch.Tensor,
 | |
|     s: torch.Tensor
 | |
| ) -> torch.Tensor:
 | |
|     """
 | |
|     A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper.
 | |
|     This is just for demonstration purposes, with no numerical stabilities guaranteed.
 | |
|     """
 | |
| 
 | |
|     dtype = q.dtype
 | |
|     q, k, v, s = map(lambda x: x.float(), (q, k, v, s))
 | |
| 
 | |
|     scale = q.shape[-1] ** -0.5
 | |
|     # [batch_size, n_heads, seq_len, n_slots]
 | |
|     s = (s - s.max(2, True)[0]).exp()
 | |
|     z = s.cumsum(2)
 | |
|     # [batch_size, n_heads, seq_len, n_slots, d_head]
 | |
|     K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
 | |
|     V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
 | |
|     # [batch_size, n_heads, seq_len, n_slots]
 | |
|     p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1)
 | |
|     # [batch_size, n_heads, seq_len, d_head]
 | |
|     o = torch.einsum('...m,...md->...d', p, V)
 | |
|     return o.to(dtype), None
 |