# -*- coding: utf-8 -*-

import torch

from fla.ops.rebased.parallel import parallel_rebased
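
# Sanity check for the ReBased parallel kernel: compares the fused Triton
# kernel `parallel_rebased` against a naive quadratic PyTorch reference on
# both outputs and input gradients. Requires a CUDA device.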


def naive_parallel_rebased(q, k, v, use_scale=True, use_norm=True):
    """Quadratic-time reference: causal attention with the squared
    dot-product kernel attn[i, j] = (q_i . k_j) ** 2."""
    if use_scale:
        q = q * (q.shape[-1] ** -0.5)
    attn = (q @ k.transpose(-2, -1)) ** 2
    # Causal mask: zero out attention to future positions.
    attn.masked_fill_(~torch.tril(torch.ones(
        q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
    o = attn @ v
    if use_norm:
        # Row-normalize by the total attention mass; the epsilon guards
        # against division by zero.
        z = attn.sum(-1)
        return o / (z[..., None] + 1e-6)
    return o
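

# For illustration: a linear-time equivalent of the reference above (a
# hypothetical helper, not part of the fla API). Since (q . k) ** 2 equals
# <phi(q), phi(k)> with phi(x) = vec(outer(x, x)), the causal sums can be
# taken over cumulative key/value statistics, which is the property the
# fused kernel exploits. Minimal sketch; it materializes large
# [*, D*D, Dv] intermediates, so it is for illustration only.
def naive_linear_rebased(q, k, v, use_scale=True, use_norm=True):
    if use_scale:
        q = q * (q.shape[-1] ** -0.5)
    # Second-order feature maps, shape [B, H, L, D*D].
    phi_q = torch.einsum('...i,...j->...ij', q, q).flatten(-2)
    phi_k = torch.einsum('...i,...j->...ij', k, k).flatten(-2)
    # Causal prefix sums of phi(k) v^T, shape [B, H, L, D*D, Dv].
    kv = torch.einsum('...ld,...lm->...ldm', phi_k, v).cumsum(-3)
    o = torch.einsum('...ld,...ldm->...lm', phi_q, kv)
    if use_norm:
        # Normalizer: phi(q_i) . sum_{j <= i} phi(k_j).
        z = torch.einsum('...ld,...ld->...l', phi_q, phi_k.cumsum(-2))
        return o / (z[..., None] + 1e-6)
    return o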


if __name__ == "__main__":
    B = 4      # batch size
    H = 4      # number of heads
    L = 128    # sequence length
    dtype = torch.float32
    # Queries/keys have head dim 16; values have head dim 128.
    q = torch.randn(B, H, L, 16).cuda().to(dtype).requires_grad_(True)
    k = torch.randn(B, H, L, 16).cuda().to(dtype).requires_grad_(True)
    v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True)

    # Reference forward/backward with a random upstream gradient.
    do = torch.randn_like(v)
    ref = naive_parallel_rebased(q, k, v, True, True)
    ref.backward(do, retain_graph=True)
    ref_dq, q.grad = q.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dv, v.grad = v.grad.clone(), None

    # tri = naive_chunk_based(q, k, v)
    # tri.backward(do, retain_graph=True)
    # tri_dq, q.grad = q.grad.clone(), None
    # tri_dk, k.grad = k.grad.clone(), None
    # tri_dv, v.grad = v.grad.clone(), None

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
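
    # Fused Triton kernel under test; eps=1e-6 matches the epsilon used in
    # the naive normalization.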
    tri = parallel_rebased(q, k, v, 1e-6, True, True)
    tri.backward(do, retain_graph=True)
    tri_dq, q.grad = q.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dv, v.grad = v.grad.clone(), None

    # Maximum absolute deviations of output and gradients from the reference.
    print((ref - tri).abs().max())
    print((ref_dq - tri_dq).abs().max())
    print((ref_dk - tri_dk).abs().max())
    print((ref_dv - tri_dv).abs().max())

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()

    # tri = parallel_based(q, k, v, True, True)
    # tri.backward(do, retain_graph=True)
    # tri_dq, q.grad = q.grad.clone(), None
    # tri_dk, k.grad = k.grad.clone(), None
    # tri_dv, v.grad = v.grad.clone(), None

    # print((ref - tri).abs().max())
    # print((ref_dq - tri_dq).abs().max())
    # print((ref_dk - tri_dk).abs().max())
    # print((ref_dv - tri_dv).abs().max())

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()