139 lines
4.2 KiB
Python
Vendored
139 lines
4.2 KiB
Python
Vendored
import triton
|
|
import triton.language as tl
|
|
|
|
# 1 / ln(2). Multiplying a natural-log decay by this converts it to log2 units,
# which lets the kernels below exponentiate with exp2.
# Wrapped in tl.constexpr because recent Triton releases only let @triton.jit
# functions read module-level globals that are constexpr instances; a plain
# float global makes fwd_decay_cumsum fail to compile there.
# NOTE(review): host-side code that needs a plain Python float should use
# `inv_ln2.value`.
inv_ln2 = tl.constexpr(1.44269504)
|
|
|
|
|
|
|
|
@triton.jit
def fwd_decay_cumsum(
    g,
    g_o,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    B,
    H,
    T,
    scale,
    BT: tl.constexpr,
    BK: tl.constexpr,
    DK: tl.constexpr
):
    """Write the running sum of ``g`` over the rows of one chunk, in log2 units.

    Grid axes: 0 -> feature slice (BK columns), 1 -> chunk (BT rows),
    2 -> batch-head. Row ``t`` of the chunk in ``g_o`` receives
    ``sum_{s <= t} g[s] * inv_ln2``, accumulated in float32 and cast to
    ``g_o``'s element type on store. The sum restarts at zero in every chunk.

    Addressing uses ``DK`` as the row stride and a unit feature stride. The
    ``s_qk_t``/``s_qk_d`` strides and ``B``/``H``/``T``/``scale`` are accepted
    but unused. Only the feature axis is masked, so T is presumably padded
    to a multiple of BT by the launcher -- TODO confirm.
    """
    i_k = tl.program_id(0)
    i_c = tl.program_id(1)
    i_bh = tl.program_id(2)

    cols = i_k * BK + tl.arange(0, BK)
    col_ok = cols < DK

    # First row of this chunk, restricted to this program's feature columns.
    chunk_in = g + i_bh * s_qk_h + i_c * BT * DK + cols
    chunk_out = g_o + i_bh * s_qk_h + i_c * BT * DK + cols

    running = tl.zeros([BK], dtype=tl.float32)
    for t in range(BT):
        g_t = tl.load(chunk_in + t * DK, mask=col_ok, other=0).to(tl.float32)
        # Scale by 1/ln(2) so downstream kernels can use exp2 on the result.
        running += g_t * inv_ln2
        tl.store(chunk_out + t * DK, running.to(g_o.dtype.element_ty), mask=col_ok)
|
|
|
|
@triton.jit
def prepare_qg_kg(
    q,
    k,
    g,
    qg,
    kg,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    B,
    H,
    T,
    scale,
    BT: tl.constexpr,
    BK: tl.constexpr,
    DK: tl.constexpr
):
    """Apply per-row decay factors to ``q`` and ``k`` for one chunk.

    Grid axes: 0 -> feature slice (BK columns), 1 -> chunk (BT rows),
    2 -> batch-head. For each row ``t`` of the chunk, where ``last`` is the
    chunk's final row:

        qg[t] = q[t] * exp2(g[t]) * scale
        kg[t] = k[t] * exp2(g[last] - g[t])

    ``g`` is used directly as a base-2 exponent, so it presumably holds the
    log2-scaled, chunk-local cumulative decay (e.g. the output of
    ``fwd_decay_cumsum``) -- TODO confirm at the call site.

    Addressing uses ``DK`` as the row stride and a unit feature stride. The
    ``s_qk_t``/``s_qk_d`` strides and ``B``/``H``/``T`` are accepted but
    unused. Only the feature axis is masked, so T is presumably padded to a
    multiple of BT by the launcher -- TODO confirm.
    """
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # All pointers start at the first row of chunk i_c and advance one row per step.
    p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
    p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
    p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
    p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
    p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)

    # Lanes past DK exist when DK is not a multiple of BK.
    mask = (i_k * BK + tl.arange(0, BK)) < DK

    # Decay at the chunk's last row. It must be masked like every other access:
    # unmasked lanes past DK would read out of bounds. Cast to float32 to match _g.
    last_decay = tl.load(
        g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK),
        mask=mask,
        other=0,
    ).to(tl.float32)

    for i in range(BT):
        _q = tl.load(p_q, mask=mask, other=0)
        _k = tl.load(p_k, mask=mask, other=0)
        _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        _q *= tl.math.exp2(_g) * scale
        _k *= tl.math.exp2(last_decay - _g)
        tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)
        tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)
        p_q += DK
        p_g += DK
        p_k += DK
        p_kg += DK
        p_qg += DK
|
|
|
|
|
|
@triton.jit
def bwd_decay_global_cumsum(
    dq_inner,
    dq_inter,
    dk_inner,
    dk_inter,
    q, k, g, dg,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    B,
    H,
    T,
    scale,
    BT: tl.constexpr,
    BK: tl.constexpr,
    DK: tl.constexpr
):
    """Combine the two gradient parts and reverse-accumulate ``dg`` within a chunk.

    Grid axes: 0 -> feature slice (BK columns), 1 -> chunk (BT rows),
    2 -> batch-head. Rows are visited from the chunk's last row (``last``)
    back to its first. For each row ``t``:

        dq = dq_inner[t] + dq_inter[t] * exp2(g[t])            -> stored to dq_inter[t]
        dk = dk_inner[t] + dk_inter[t] * exp2(g[last] - g[t])  -> stored to dk_inter[t]
        dg[t] = sum over rows s >= t of this chunk of (dq[s] * q[s] - dk[s] * k[s])

    ``dq_inter`` and ``dk_inter`` are therefore overwritten in place with the
    combined gradients. ``dg`` is accumulated in float32 and restarts at zero
    in every chunk; no carry crosses chunk boundaries in this kernel.
    "inner"/"inter" presumably denote intra-chunk and inter-chunk
    contributions -- TODO confirm against the caller.

    Addressing uses ``DK`` as the row stride and a unit feature stride. The
    ``s_qk_t``/``s_qk_d`` strides and ``B``/``H``/``T``/``scale`` are accepted
    but unused. Only the feature axis is masked, so T is presumably padded
    to a multiple of BT by the launcher -- TODO confirm.
    """
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # Every pointer starts at the LAST row of chunk i_c and steps backwards.
    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
    p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
    p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
    p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
    p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
    p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
    p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK

    cum_grad_dg = tl.zeros([BK], dtype=tl.float32)
    # Lanes past DK exist when DK is not a multiple of BK.
    mask = (i_k * BK + tl.arange(0, BK)) < DK

    # Decay at the chunk's last row. p_g points at that row right now, so load it
    # once here instead of capturing it through a branch inside the loop.
    last_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)

    # BT iterations, last row -> first row (pointers are decremented below).
    for _ in range(BT):
        _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)

        _dq1 = tl.load(p_dq_inner, mask=mask, other=0)
        _dq2 = tl.load(p_dq_inter, mask=mask, other=0)
        _dq2 *= tl.math.exp2(_g)
        _dq = _dq1 + _dq2
        tl.store(p_dq_inter, _dq.to(p_dq_inter.dtype.element_ty), mask=mask)

        _dk1 = tl.load(p_dk_inner, mask=mask, other=0)
        _dk2 = tl.load(p_dk_inter, mask=mask, other=0)
        _dk2 *= tl.math.exp2(last_g - _g)
        _dk = _dk1 + _dk2
        tl.store(p_dk_inter, _dk.to(p_dk_inter.dtype.element_ty), mask=mask)

        _q = tl.load(p_q, mask=mask, other=0)
        _k = tl.load(p_k, mask=mask, other=0)
        _dg = _dq * _q - _dk * _k
        # Reverse running sum: dg[t] covers rows t..last of this chunk.
        cum_grad_dg += _dg
        tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)

        p_g -= DK
        p_k -= DK
        p_q -= DK
        p_dq_inner -= DK
        p_dk_inner -= DK
        p_dq_inter -= DK
        p_dk_inter -= DK
        p_dg -= DK
|
|
|