# -*- coding: utf-8 -*-

import torch
from einops import rearrange

from fla.ops.rebased.parallel import parallel_rebased


def naive_parallel_rebased(q, k, v, use_scale=True, use_norm=True, eps=1e-6):
    """Pure-PyTorch reference used to check `parallel_rebased`.

    Computes causal attention whose weights are the element-wise square of
    the query/key dot products, materialising the full score matrix.

    Args:
        q: Queries; the last dim is the feature dim and the second-to-last
            is the sequence dim (the script below passes `(B, H, L, 16)`).
        k: Keys, multiplied against `q` as `q @ k.transpose(-2, -1)`.
        v: Values, multiplied as `attn @ v` (the script below passes
            `(B, H, L, 128)`).
        use_scale: If true, scale `q` by `q.shape[-1] ** -0.5` first.
        use_norm: If true, divide the output by the per-query sum of the
            masked weights (plus `eps`).
        eps: Constant added to the normaliser to avoid division by zero.

    Returns:
        The attention output, normalised when `use_norm` is true.
    """
    if use_scale:
        q = q * (q.shape[-1] ** -0.5)

    # Squared dot-product scores for every (query, key) pair.
    attn = (q @ k.transpose(-2, -1)) ** 2

    # Lower-triangular mask: position i may only attend to positions <= i.
    seq_len = q.shape[-2]
    causal_mask = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device)
    )
    # Out-of-place fill so no autograd intermediate is mutated.
    attn = attn.masked_fill(~causal_mask, 0)

    o = attn @ v
    if not use_norm:
        return o

    z = attn.sum(-1)
    return o / (z[..., None] + eps)


if __name__ == "__main__":
    B = 4
    H = 4
    L = 128
    # D = 15
    dtype = torch.float32
    # One epsilon shared by the reference and the kernel under test.
    eps = 1e-6

    q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
    k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
    v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True)
    # Upstream gradient fed to both backward passes.
    do = torch.randn_like(v)

    # Reference forward/backward; grads are cloned and then cleared so the
    # next backward pass starts from empty `.grad` fields.
    ref = naive_parallel_rebased(q, k, v, True, True, eps)
    ref.backward(do, retain_graph=True)
    ref_dq, q.grad = q.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dv, v.grad = v.grad.clone(), None

    # NOTE(review): the commented-out checks below reference
    # `naive_chunk_based`, which is not imported in this file.
    # tri = naive_chunk_based(q, k, v)
    # tri.backward(do, retain_graph=True)
    # tri_dq, q.grad = q.grad.clone(), None
    # tri_dk, k.grad = k.grad.clone(), None
    # tri_dv, v.grad = v.grad.clone(), None

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()

    # Kernel under test, run on the same inputs and upstream gradient.
    tri = parallel_rebased(q, k, v, eps, True, True)
    tri.backward(do, retain_graph=True)
    tri_dq, q.grad = q.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dv, v.grad = v.grad.clone(), None

    # Max absolute error of the output and of each input gradient.
    print((ref-tri).abs().max())
    print((ref_dq-tri_dq).abs().max())
    print((ref_dk-tri_dk).abs().max())
    print((ref_dv-tri_dv).abs().max())

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()

    # NOTE(review): `parallel_based` is likewise not imported in this file.
    # tri = parallel_based(q, k, v, True, True)
    # tri.backward(do, retain_graph=True)
    # tri_dq, q.grad = q.grad.clone(), None
    # tri_dk, k.grad = k.grad.clone(), None
    # tri_dv, v.grad = v.grad.clone(), None

    # print((ref-tri).abs().max())
    # print((ref_dq-tri_dq).abs().max())
    # print((ref_dk-tri_dk).abs().max())
    # print((ref_dv-tri_dv).abs().max())

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()