RWKV-Runner/finetune/lora/v6/fla/ops/delta_rule/naive.py
2024-05-28 22:35:47 +08:00

93 lines
3.1 KiB
Python
Vendored

# -*- coding: utf-8 -*-
import torch
from einops import rearrange
def delta_rule_recurrence(q, k, v, beta):
    """Reference delta-rule recurrence, evaluated one time step at a time.

    A state matrix ``S`` of shape ``(b, h, d_k, d_v)`` starts at zero and,
    for every step ``i`` along the third axis, is updated as::

        u_i = beta_i * (v_i - k_i @ S)
        S   = S + outer(k_i, u_i)
        o_i = (q_i * d_k ** -0.5) @ S

    Args:
        q: Queries, unpacked as ``(b, h, l, d_k)``.
        k: Keys, indexed the same way as ``q``.
        v: Values, ``(b, h, l, d_v)``.
        beta: Per-step write strength, indexed as ``(b, h, l)``.

    Returns:
        Tensor shaped like ``v`` holding the output of every step.
    """
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    # Allocate the state directly with v's dtype and device instead of
    # building it on the CPU first and copying it over.
    S = v.new_zeros(b, h, d_k, d_v)
    q = q * (d_k ** -0.5)
    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        beta_i = beta[:, :, i]
        # Residual between the incoming value and what the current state
        # already returns for this key, scaled by the write strength.
        _v = v[:, :, i] - (S * _k[..., None]).sum(-2)
        _v = _v * beta_i[..., None]
        # S is rebound to a new tensor on every step (never mutated in
        # place), so autograd needs no defensive clone here.
        S = S + _k.unsqueeze(-1) * _v.unsqueeze(-2)
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    return o
def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
b, h, l, d_k = q.shape
d_v = v.shape[-1]
q = q * (d_k ** -0.5)
v = v * beta[..., None]
k_beta = k * beta[..., None]
assert l % chunk_size == 0
# note that diagonal is masked.
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
for i in range(1, chunk_size):
attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
# u
k_cumsum = attn @ v
# w
k_cumdecay = attn @ k_beta
v = k_cumsum
S = k.new_zeros(b, h, d_k, d_v)
o = torch.zeros_like(v)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
for i in range(0, l // chunk_size):
q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i]
attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
v_prime = k_cumdecay[:, :, i] @ S
v_new = v_i - v_prime
o_inter = q_i @ S
o[:, :, i] = o_inter + attn @ v_new
# chunk state update
S = S + k_i.transpose(-1, -2) @ v_new
return rearrange(o, 'b h n c d -> b h (n c) d')
if __name__ == '__main__':
    # Self-test: the chunkwise implementation must reproduce the step-by-step
    # recurrence, both in its output and in the gradients of every input.
    B = 2
    H = 4
    L = 256
    DK = 128
    DV = 128
    # Use the GPU when one is present, but keep the check runnable without it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    q = torch.randn(B, H, L, DK, device=device).requires_grad_(True)
    k = torch.randn(B, H, L, DK, device=device)
    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
    v = torch.randn(B, H, L, DV, device=device).requires_grad_(True)
    beta = torch.randn(B, H, L, device=device).sigmoid().requires_grad_(True)
    do = torch.randn(B, H, L, DV, device=device)

    o = delta_rule_recurrence(q, k, v, beta)
    o.backward(do, retain_graph=True)
    # Stash the reference gradients and clear .grad so the second backward
    # pass starts from a clean slate instead of accumulating.
    q_grad, q.grad = q.grad, None
    k_grad, k.grad = k.grad, None
    v_grad, v.grad = v.grad, None
    beta_grad, beta.grad = beta.grad, None

    o2 = delta_rule_chunkwise(q, k, v, beta)
    o2.backward(do)

    # (label, input, other) in the argument order passed to torch.allclose.
    comparisons = (
        ('o', o, o2),
        ('q.grad', q.grad, q_grad),
        ('k.grad', k.grad, k_grad),
        ('v.grad', v.grad, v_grad),
        ('beta.grad', beta.grad, beta_grad),
    )
    for label, lhs, rhs in comparisons:
        # Report which tensor diverged and by how much, rather than dropping
        # into a debugger that would hang a non-interactive run.
        if not torch.allclose(lhs, rhs, atol=1e-4):
            max_diff = (lhs - rhs).abs().max().item()
            raise AssertionError(
                f"{label} mismatch between recurrence and chunkwise: "
                f"max abs diff {max_diff:.3e}"
            )
    print("All passed!")