finetune/lora/v6/fla/ops/retention/naive.py (vendored, new file, 15 lines)
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-

import torch


def naive_retention(q, k, v):
    orig_type = q.dtype
    q, k, v = q.float(), k.float(), v.float()
    _, n_heads, seq_len, d_head = q.shape
    s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log2()
    n = q.new_tensor(range(seq_len), dtype=torch.float)
    n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n)
    s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype))
    o = torch.einsum('bhqk,bhkd->bhqd', s, v)
    return o.to(orig_type)
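For context (not part of the commit): this is an O(T^2) reference implementation of the parallel form of retention, as in RetNet. Each head h gets a fixed decay rate gamma_h = 1 - 2^(-5 - h), so higher-indexed heads decay more slowly; `s` first holds log2(gamma_h), the mask `n[h, i, j]` equals gamma_h^(i - j) for i >= j and 0 otherwise (built in log2 space via `exp2` and a causal `ge` mask), and the two einsums then apply this decay mask to a 1/sqrt(d_head)-scaled attention-style product. The whole computation runs in fp32 and the output is cast back to the input dtype.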
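A minimal usage sketch (an assumption, not part of the commit: the import path presumes the vendored `fla` package is importable, and the tensor layout follows the `(batch, n_heads, seq_len, d_head)` unpacking inside the function):

import torch

from fla.ops.retention.naive import naive_retention

B, H, T, D = 2, 4, 128, 64
q = torch.randn(B, H, T, D, dtype=torch.bfloat16)
k = torch.randn(B, H, T, D, dtype=torch.bfloat16)
v = torch.randn(B, H, T, D, dtype=torch.bfloat16)

o = naive_retention(q, k, v)          # computed in fp32 internally
assert o.shape == (B, H, T, D)
assert o.dtype == torch.bfloat16      # cast back to the input dtype on return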