finetune/lora/v6/fla/ops/rebased/naive.py (vendored, new file, 80 lines)
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-

import torch
from einops import rearrange

from fla.ops.rebased.parallel import parallel_rebased


def naive_parallel_rebased(q, k, v, use_scale=True, use_norm=True):
    # Quadratic-attention reference: scores are the squared dot products
    # (q . k)^2 under a causal (lower-triangular, diagonal included) mask,
    # with an optional 1/sqrt(d) scaling of q applied before squaring and
    # an optional normalisation by the per-row score sums.
    if use_scale:
        q = q * (q.shape[-1] ** -0.5)
    attn = q @ k.transpose(-2, -1)
    attn = attn ** 2
    attn.masked_fill_(~torch.tril(torch.ones(
        q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
    o = attn @ v
    if use_norm:
        z = attn.sum(-1)
        return o / (z[..., None] + 1e-6)
    else:
        return o
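

# The function below is an illustrative sketch added for clarity and is NOT
# part of the vendored file; its name is hypothetical. It relies on the
# identity (q . k)^2 = <phi(q), phi(k)> for the quadratic feature map
# phi(x) = vec(x x^T), so the same output as naive_parallel_rebased can be
# formed from cumulative sums instead of the L x L attention matrix.
# It materialises a (B, H, L, D*D, M) state, so it is for clarity only.
def naive_rebased_via_feature_map(q, k, v, use_scale=True, use_norm=True):
    if use_scale:
        q = q * (q.shape[-1] ** -0.5)
    # phi(x): outer product of x with itself, flattened to d*d features
    phi_q = rearrange(q[..., :, None] * q[..., None, :], 'b h l d e -> b h l (d e)')
    phi_k = rearrange(k[..., :, None] * k[..., None, :], 'b h l d e -> b h l (d e)')
    # running state sum_{j<=i} phi(k_j) v_j^T, then contract with phi(q_i)
    kv = torch.einsum('bhld,bhlm->bhldm', phi_k, v).cumsum(2)
    o = torch.einsum('bhld,bhldm->bhlm', phi_q, kv)
    if use_norm:
        # normaliser: phi(q_i) . sum_{j<=i} phi(k_j), i.e. the masked row sums
        z = torch.einsum('bhld,bhld->bhl', phi_q, phi_k.cumsum(2))
        return o / (z[..., None] + 1e-6)
    return o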


if __name__ == "__main__":
    B = 4
    H = 4
    L = 128
    # D = 15
    dtype = torch.float32
    q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
    k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
    v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True)

    do = torch.randn_like(v).cuda()
    # reference: the quadratic-attention implementation above
    ref = naive_parallel_rebased(q, k, v, True, True)
    ref.backward(do, retain_graph=True)
    ref_dq, q.grad = q.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dv, v.grad = v.grad.clone(), None

    # tri = naive_chunk_based(q, k, v)
    # tri.backward(do, retain_graph=True)
    # tri_dq, q.grad = q.grad.clone(), None
    # tri_dk, k.grad = k.grad.clone(), None
    # tri_dv, v.grad = v.grad.clone(), None

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()

    # compare outputs and gradients against the fused parallel_rebased kernel
    tri = parallel_rebased(q, k, v, 1e-6, True, True)
    tri.backward(do, retain_graph=True)
    tri_dq, q.grad = q.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dv, v.grad = v.grad.clone(), None
    # maximum absolute deviation between reference and kernel
    print((ref - tri).abs().max())
    print((ref_dq - tri_dq).abs().max())
    print((ref_dk - tri_dk).abs().max())
    print((ref_dv - tri_dv).abs().max())

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()

    # tri = parallel_based(q, k, v, True, True)
    # tri.backward(do, retain_graph=True)
    # tri_dq, q.grad = q.grad.clone(), None
    # tri_dk, k.grad = k.grad.clone(), None
    # tri_dv, v.grad = v.grad.clone(), None

    # print((ref-tri).abs().max())
    # print((ref_dq-tri_dq).abs().max())
    # print((ref_dk-tri_dk).abs().max())
    # print((ref_dv-tri_dv).abs().max())

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
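    # (hypothetical check added for illustration, not in the vendored file)
    # the feature-map sketch above should agree with the quadratic reference:
    # assert ref.allclose(naive_rebased_via_feature_map(q, k, v, True, True), 0, 1e-4), breakpoint()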