# -*- coding: utf-8 -*-
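
# Naive recurrent implementation of gated linear attention (GLA), kept as a
# readable reference for checking the fused Triton kernel `fused_recurrent_gla`.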
import torch
import torch.nn.functional as F
from fla.ops.gla.recurrent_fuse import fused_recurrent_gla


def ceildiv(a, b):
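    # Ceiling division: equivalent to math.ceil(a / b) for integers
    # (unused in this script, kept as a small utility).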
    return -(a // -b)


def naive_recurrent_gla(
    q,
    k,
    v,
    gk,
    initial_state=None,
    output_final_state=False,
    causal=True
):
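    """
    Naive sequential reference implementation of gated linear attention (GLA).

    q, k, gk: (batch, heads, seq_len, d_head_k); gk holds log-space forget
    gates and is exponentiated inside the loop. v: (batch, heads, seq_len,
    d_head_v). The recurrence per timestep is

        h_t = exp(gk_t) * h_{t-1} + k_t v_t^T    (state: d_head_k x d_head_v)
        o_t = h_t^T (q_t / sqrt(d_head_k))

    With causal=True, returns (o, final_state); with causal=False, it also
    runs a right-to-left scan and returns (o, o_reverse). `output_final_state`
    is accepted for signature parity with the fused kernel.
    """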
    orig_dtype = q.dtype
    q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape
    h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
    o = torch.zeros_like(v)
    scale = d_head_k ** -0.5

    if initial_state is not None:
        h += initial_state

    for i in range(seq_len):
        q_i = q[:, :, i, :] * scale
        k_i = k[:, :, i]
        v_i = v[:, :, i, :]
        gk_i = gk[:, :, i].exp()
        # Decay the running state by the gate, then add the rank-1 update k_t v_t^T.
        kv_i = k_i[..., None] * v_i[..., None, :]
        h = h * gk_i[..., None] + kv_i
        # Read out: o_t[dv] = sum_dk q_t[dk] * h[dk, dv].
        o_i = (q_i[..., None] * h).sum(-2)
        o[:, :, i] = o_i

    if causal:
        return o.to(orig_dtype), h
    else:
        # Non-causal mode: run a second scan right-to-left and return the
        # outputs of both directions.
        o_reverse = torch.zeros_like(v)
        h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
        for i in range(seq_len - 1, -1, -1):
            q_i = q[:, :, i, :] * scale
            k_i = k[:, :, i]
            v_i = v[:, :, i, :]
            gk_i = gk[:, :, i].exp()
            kv_i = k_i[..., None] * v_i[..., None, :]
            h = h * gk_i[..., None] + kv_i
            o_i = (q_i[..., None] * h).sum(-2)
            o_reverse[:, :, i] = o_i

        # Cast back to the input dtype, matching the causal branch.
        return o.to(orig_dtype), o_reverse.to(orig_dtype)
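

# A quadratic-time closed form of the same recurrence, usable as an extra
# sanity check. This is a sketch added for illustration, not part of the
# original script; the helper name `naive_parallel_gla` is ours. With
# G = cumsum(gk) over time, the recurrence unrolls to
#     o_t = sum_{s<=t} ((q_t * exp(G_t)) . (k_s * exp(-G_s))) * v_s / sqrt(d_k),
# i.e. a causally masked matmul. Note exp(-G_s) grows with s, so this form is
# numerically safe only for short sequences.
def naive_parallel_gla(q, k, v, gk):
    q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk))
    scale = q.shape[-1] ** -0.5
    G = gk.cumsum(-2)                                # cumulative log-gates G_t
    q_hat = q * G.exp() * scale                      # fold exp(G_t) into q
    k_hat = k * (-G).exp()                           # fold exp(-G_s) into k
    attn = (q_hat @ k_hat.transpose(-1, -2)).tril()  # causal mask: keep s <= t
    return attn @ v

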
if __name__ == "__main__":
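    # Forward and gradient parity check of the naive reference against the
    # fused Triton kernel, in the non-causal (bidirectional) setting.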
    B = 4
    H = 4
    L = 512
    D = 128
    dtype = torch.float32
    q = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
    k = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
    v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
    # log-sigmoid keeps the gates negative (decay < 1); clamp_min(-1) bounds the decay.
    g = F.logsigmoid(torch.rand(B, H, L, D)).cuda().clamp_min(-1).to(torch.float32).requires_grad_(True)

    do = torch.rand_like(v).cuda()
    do2 = torch.rand_like(v).cuda()
    initial_state = torch.rand(B, H, D, D).cuda()  # not passed to either implementation below

    ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False)

    ref.backward(do, retain_graph=True)
    ref_rev.backward(do2, retain_graph=True)

    ref_dq, q.grad = q.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dv, v.grad = v.grad.clone(), None
    ref_dg, g.grad = g.grad.clone(), None

    tri, tri_rev = fused_recurrent_gla(
        q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, causal=False)
    tri.backward(do, retain_graph=True)
    tri_rev.backward(do2, retain_graph=True)
    tri_dq, q.grad = q.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dv, v.grad = v.grad.clone(), None
    tri_dg, g.grad = g.grad.clone(), None

    # In `Tensor.allclose(other, 0, 1e-5)` the positional args are rtol=0 and
    # atol=1e-5; `breakpoint()` in the assert message is evaluated only on
    # failure, dropping into pdb with the mismatching tensors in scope.
    assert ref.allclose(tri, 0, 1e-5), breakpoint()
    assert ref_rev.allclose(tri_rev, 0, 1e-5), breakpoint()
    assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
    assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
    assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
    assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
|
# tri = fused_chunk_gla(q, k, v, g)
|
|
# tri.backward(do, retain_graph=True)
|
|
# tri_dq, q.grad = q.grad.clone(), None
|
|
# tri_dk, k.grad = k.grad.clone(), None
|
|
# tri_dv, v.grad = v.grad.clone(), None
|
|
# tri_dg, g.grad = g.grad.clone(), None
|
|
|
|
# assert ref.allclose(tri, 0, 1e-5), breakpoint()
|
|
# assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
|
|
# assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
|
|
# assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
|
|
# assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
|
|
# breakpoint()
|
|
print("Pass")
|