This commit is contained in:
5
finetune/lora/v6/fla/ops/simple_gla/README.md
vendored
Normal file
5
finetune/lora/v6/fla/ops/simple_gla/README.md
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
- Simple GLA
|
||||
|
||||
Gating mechanism in https://arxiv.org/abs/2103.02143. Compared to GLA, the gating is head-wise instead of elementwise. As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. It is faster than GLA but has less expressive power. I will use it as a baseline for the GLA.
|
||||
|
||||
$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.
|
||||
8
finetune/lora/v6/fla/ops/simple_gla/__init__.py
vendored
Normal file
8
finetune/lora/v6/fla/ops/simple_gla/__init__.py
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .chunk import chunk_simple_gla
|
||||
|
||||
__all__ = [
|
||||
'chunk_simple_gla'
|
||||
]
|
||||
|
||||
415
finetune/lora/v6/fla/ops/simple_gla/chunk.py
vendored
Normal file
415
finetune/lora/v6/fla/ops/simple_gla/chunk.py
vendored
Normal file
@@ -0,0 +1,415 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def normalize_output(q, k, o):
|
||||
k = k.transpose(-2, -1)
|
||||
k = k.cumsum(-1)
|
||||
k = k.transpose(-2, -1)
|
||||
z = (q * k).sum(-1, keepdim=True)
|
||||
return o / (z + 1e-5)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_simple_gla_fwd_kernel_h(
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
# [BK, BV]
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V,
|
||||
(K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
for i_t in range(NT):
|
||||
p_k = tl.make_block_ptr(
|
||||
k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(
|
||||
v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,
|
||||
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)
|
||||
b_h *= tl.math.exp2(b_g_last)
|
||||
b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT))
|
||||
b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(
|
||||
final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_simple_gla_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_s = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(
|
||||
k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,
|
||||
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT]
|
||||
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
b_o += tl.dot(b_q, b_h, allow_tf32=False)
|
||||
b_s += tl.dot(b_q, b_k, allow_tf32=False)
|
||||
|
||||
p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
|
||||
b_g = tl.load(p_g)
|
||||
b_o = b_o * tl.math.exp2(b_g)[:, None]
|
||||
b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :])
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V),
|
||||
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V),
|
||||
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_simple_gla_bwd_kernel_dh(
|
||||
q,
|
||||
g,
|
||||
do,
|
||||
dh,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
# [BK, BV]
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for i_t in range(NT - 1, -1, -1):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V,
|
||||
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T +
|
||||
i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)
|
||||
# [BT, V]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1))
|
||||
b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_simple_gla_bwd_kernel_dqkv(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
do,
|
||||
dh,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
n_bh = tl.num_programs(2)
|
||||
o_i = tl.arange(0, BT)
|
||||
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T),
|
||||
(s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K),
|
||||
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_s = tl.dot(b_k, b_q, allow_tf32=False)
|
||||
p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
|
||||
b_g = tl.load(p_g)
|
||||
b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)
|
||||
mask = tl.math.exp2(b_g[None, :] - b_g[:, None])
|
||||
mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)
|
||||
b_s = b_s * mask
|
||||
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(
|
||||
v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t),
|
||||
(i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V),
|
||||
(s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V),
|
||||
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BV, BK]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh = tl.load(p_dh, boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
|
||||
# [BT, BV]
|
||||
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.math.exp2(-b_g + b_g_last)[:, None] + \
|
||||
tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
b_dq = b_dq * tl.math.exp2(b_g)[:, None]
|
||||
b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None]
|
||||
b_ds = b_ds * tl.trans(mask)
|
||||
b_ds = b_ds.to(b_k.dtype)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
|
||||
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K),
|
||||
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K),
|
||||
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class SimpleGLAFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
@contiguous
|
||||
def forward(ctx, q, k, v, g, initial_state, output_final_state):
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT = 64
|
||||
BK, BV = min(64, triton.next_power_of_2(K)), min(
|
||||
64, triton.next_power_of_2(V))
|
||||
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
scale = K ** -0.5
|
||||
|
||||
BT = 64
|
||||
assert T % BT == 0, 'sequence length must be divisible by BT'
|
||||
g = g.reshape(B, H, -1, BT)
|
||||
g = g.cumsum(-1) * 1.44269504
|
||||
g = g.reshape(B, H, -1)
|
||||
|
||||
final_state = None
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
|
||||
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_simple_gla_fwd_kernel_h[grid](
|
||||
k, v, h, g, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2),
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=output_final_state,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
grid = (NV, NT, B * H)
|
||||
o = torch.empty_like(v)
|
||||
chunk_simple_gla_fwd_kernel_o[grid](
|
||||
q, k, v, h, g, o,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2),
|
||||
scale,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, h, g)
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
@contiguous
|
||||
def backward(ctx, do, d_ht=None):
|
||||
q, k, v, h, g = ctx.saved_tensors
|
||||
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT = 64
|
||||
BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min(
|
||||
32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
|
||||
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
scale = K ** -0.5
|
||||
|
||||
dh = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_simple_gla_bwd_kernel_dh[grid](
|
||||
q, g, do, dh,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
dh.stride(1), dh.stride(2),
|
||||
scale,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
grid = (NK, NT, B * H)
|
||||
dq = torch.empty_like(q)
|
||||
dk = torch.empty_like(k)
|
||||
dv = v.new_empty(NK, *v.shape)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
chunk_simple_gla_bwd_kernel_dqkv[grid](
|
||||
q, k, v, h, g, do, dh, dq, dk, dv,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
dh.stride(1), dh.stride(2),
|
||||
scale,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
dv = dv.sum(0)
|
||||
dg = (dq * q - dk * k).sum(-1)
|
||||
|
||||
def rev_cumsum(x):
|
||||
cumsum_x = x.cumsum(-1)
|
||||
rev_cumsum_x = cumsum_x[..., -1, None] - cumsum_x
|
||||
return rev_cumsum_x + x
|
||||
dg = rev_cumsum(dg)
|
||||
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, None
|
||||
|
||||
|
||||
def chunk_simple_gla(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor, # log decay
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
g = g.float()
|
||||
o, final_state = SimpleGLAFunction.apply(q, k, v, g, initial_state, output_final_state)
|
||||
return o, final_state
|
||||
52
finetune/lora/v6/fla/ops/simple_gla/naive.py
vendored
Normal file
52
finetune/lora/v6/fla/ops/simple_gla/naive.py
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def torch_simple_gla(q, k, v, g, chunk_size=64):
|
||||
q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5)
|
||||
k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size)
|
||||
v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size)
|
||||
g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size)
|
||||
g = g.cumsum(-1)
|
||||
kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None])
|
||||
S = torch.zeros_like(kv)
|
||||
|
||||
for i in range(1, g.shape[-2]):
|
||||
S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1]
|
||||
|
||||
inter = (q * g[..., None].exp()) @ S
|
||||
attn = q @ k.transpose(-1, -2)
|
||||
attn = attn * (g[..., None] - g[..., None, :]).exp()
|
||||
attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
|
||||
intra = attn @ v
|
||||
o = inter + intra
|
||||
return rearrange(o, 'b h n c d -> b h (n c) d')
|
||||
|
||||
|
||||
def torch_simple_gla_recurrent(q, k, v, g, chunk_size=64):
|
||||
# q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5)
|
||||
# k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size)
|
||||
# v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size)
|
||||
# g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size)
|
||||
# g = g.cumsum(-1)
|
||||
# kv = k.transpose(-1, -2) @ v
|
||||
|
||||
B, H, T, DK = q.shape
|
||||
q = q * (DK ** -0.5)
|
||||
_, _, _, DV = v.shape
|
||||
S = torch.zeros(B, H, DK, DV).to(q)
|
||||
o = torch.zeros(B, H, T, DV).to(q)
|
||||
for i in range(T):
|
||||
gate = g[:, :, i].exp()
|
||||
key = k[:, :, i]
|
||||
value = v[:, :, i]
|
||||
kv = key.unsqueeze(-1) * value.unsqueeze(-2)
|
||||
S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv
|
||||
q_i = q[:, :, i, :]
|
||||
o_i = (q_i.unsqueeze(-1) * S).sum(-2)
|
||||
o[:, :, i] = o_i
|
||||
|
||||
return o
|
||||
|
||||
Reference in New Issue
Block a user