# -*- coding: utf-8 -*- import torch from einops import rearrange def torch_chunk_linear_attn(q, k, v, chunk_size=64): q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] **-0.5) k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) kv = k.transpose(-1, -2) @ v kv = kv.cumsum(2) kv = torch.cat([ torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1] ], dim=2) inter = q @ kv intra = ((q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)) @ v o = inter + intra return rearrange(o, 'b h n c d -> b h (n c) d')