# -*- coding: utf-8 -*-

import torch
from einops import rearrange


def delta_rule_recurrence(q, k, v, beta):
    """Token-by-token reference implementation of the delta rule."""
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(b, h, d_k, d_v).to(v)
    q = q * (d_k ** -0.5)
    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        _v = v[:, :, i].clone()
        beta_i = beta[:, :, i]
        # delta rule: keep only the part of v_i that the state does not yet
        # predict for k_i, scale it by the write strength beta_i, and add it
        # back to the state as a rank-1 update
        _v = _v - (S.clone() * _k[..., None]).sum(-2)
        _v = _v * beta_i[..., None]
        S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    return o
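

# Illustrative sketch, not part of the original file: the update above can
# equivalently be written as a generalized-Householder state transition,
#   S_t = (I - beta_t k_t k_t^T) S_{t-1} + beta_t k_t v_t^T,
# which makes explicit that each step first erases along k_t and then writes
# the new value. The helper name below is hypothetical.
def delta_rule_recurrence_householder(q, k, v, beta):
    b, h, l, d_k = q.shape
    o = torch.zeros_like(v)
    S = torch.zeros(b, h, d_k, v.shape[-1]).to(v)
    q = q * (d_k ** -0.5)
    I = torch.eye(d_k, dtype=v.dtype, device=v.device)
    for i in range(l):
        _q, _k, _v = q[:, :, i], k[:, :, i], v[:, :, i]
        beta_i = beta[:, :, i, None, None]
        kkT = _k.unsqueeze(-1) * _k.unsqueeze(-2)  # (b, h, d_k, d_k) outer product
        kvT = _k.unsqueeze(-1) * _v.unsqueeze(-2)  # (b, h, d_k, d_v) outer product
        S = (I - beta_i * kkT) @ S + beta_i * kvT
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    return o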


def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
    """Chunkwise-parallel form of the delta rule, processing the sequence in blocks of `chunk_size`."""
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    # fold the write strengths into values and keys once, up front
    v = v * beta[..., None]
    k_beta = k * beta[..., None]

    assert l % chunk_size == 0, 'sequence length must be divisible by chunk_size'

    # intra-chunk transition; note that the diagonal is masked here (diagonal=0),
    # unlike the strictly-upper mask used for the attention scores further down
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
    q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
    attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)

    # forward substitution: starting from -A (with A strictly lower triangular),
    # this loop plus the identity added below yields (I + A)^{-1} row by row
    for i in range(1, chunk_size):
        attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)

    attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    # u: pseudo-values, i.e. values premultiplied by the inverse transition
    k_cumsum = attn @ v
    # w: keys premultiplied the same way, used to read the inter-chunk state
    k_cumdecay = attn @ k_beta

    v = k_cumsum
    S = k.new_zeros(b, h, d_k, d_v)
    o = torch.zeros_like(v)
    # causal mask for intra-chunk attention; this time the diagonal is kept
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
    for i in range(0, l // chunk_size):
        q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i]
        # intra-chunk contribution
        attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
        # subtract what the running state already explains about this chunk
        v_prime = k_cumdecay[:, :, i] @ S
        v_new = v_i - v_prime
        # inter-chunk output read from the running state
        o_inter = q_i @ S
        o[:, :, i] = o_inter + attn @ v_new
        # chunk state update
        S = S + k_i.transpose(-1, -2) @ v_new

    return rearrange(o, 'b h n c d -> b h (n c) d')
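

# Illustrative sanity check, not part of the original file: the in-place loop in
# delta_rule_chunkwise is forward substitution, turning -A (with A strictly lower
# triangular) into (I + A)^{-1} - I before the identity is added back. The helper
# name is hypothetical; it just verifies the recurrence against torch.linalg.inv.
def _check_forward_substitution(chunk_size=16):
    A = torch.randn(chunk_size, chunk_size, dtype=torch.double).tril(-1)
    T = -A.clone()
    for i in range(1, chunk_size):
        T[i, :i] = T[i, :i] + T[i, :] @ T[:, :i]
    T = T + torch.eye(chunk_size, dtype=torch.double)
    assert torch.allclose(T, torch.linalg.inv(torch.eye(chunk_size, dtype=torch.double) + A))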
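

# Illustrative sketch, not part of the original file: delta_rule_chunkwise asserts
# that the sequence length divides evenly into chunks. A hypothetical wrapper can
# pad with zeros (zero keys, values, and betas leave the state untouched) and trim
# the padded positions from the output.
def delta_rule_chunkwise_padded(q, k, v, beta, chunk_size=32):
    l = q.shape[2]
    pad = (-l) % chunk_size
    if pad:
        # pad the length dimension on the right with zeros
        q, k, v = (torch.nn.functional.pad(x, (0, 0, 0, pad)) for x in (q, k, v))
        beta = torch.nn.functional.pad(beta, (0, pad))
    return delta_rule_chunkwise(q, k, v, beta, chunk_size)[:, :, :l]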


if __name__ == '__main__':
    B = 2
    H = 4
    L = 256
    DK = 128
    DV = 128
    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
    k = (torch.randn(B, H, L, DK)).cuda()
    # L2-normalize keys so each rank-1 update stays well-behaved
    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)

    o = delta_rule_recurrence(q, k, v, beta)
    do = torch.randn(B, H, L, DV).cuda()
    o.backward(do, retain_graph=True)
    q_grad, q.grad = q.grad, None
    k_grad, k.grad = k.grad, None
    v_grad, v.grad = v.grad, None
    beta_grad, beta.grad = beta.grad, None

    o2 = delta_rule_chunkwise(q, k, v, beta)
    o2.backward(do)
    # the second operand of each assert is evaluated only on failure, dropping into pdb
    assert torch.allclose(o, o2, atol=1e-4), breakpoint()
    assert torch.allclose(q.grad, q_grad, atol=1e-4), breakpoint()
    assert torch.allclose(k.grad, k_grad, atol=1e-4), breakpoint()
    assert torch.allclose(v.grad, v_grad, atol=1e-4), breakpoint()
    assert torch.allclose(beta.grad, beta_grad, atol=1e-4), breakpoint()
    print("All passed!")