import triton import triton.language as tl inv_ln2 = 1.44269504 @triton.jit def fwd_decay_cumsum( g, g_o, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr ): i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) cum_decay = tl.zeros([BK], dtype=tl.float32) mask = (i_k * BK + tl.arange(0, BK)) < DK for i in range(BT): _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) cum_decay += _g * inv_ln2 tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask) p_g += DK p_go += DK @triton.jit def prepare_qg_kg( q, k, g, qg, kg, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr ): i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) mask = (i_k * BK + tl.arange(0, BK)) < DK last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK)) for i in range(BT): _q = tl.load(p_q, mask=mask, other=0) _k = tl.load(p_k, mask=mask, other=0) _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) _q *= tl.math.exp2(_g) * scale _k *= tl.math.exp2(last_decay - _g) tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask) tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask) p_q += DK p_g += DK p_k += DK p_kg += DK p_qg += DK @triton.jit def bwd_decay_global_cumsum( dq_inner, dq_inter, dk_inner, dk_inter, q, k, g, dg, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale, BT: tl.constexpr, BK: tl.constexpr, DK: tl.constexpr ): i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK cum_grad_dg = tl.zeros([BK], dtype=tl.float32) mask = (i_k * BK + tl.arange(0, BK)) < DK last_g = tl.zeros([BK], dtype=tl.float32) for j in range(BT-1, -1, -1): _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) if j == (BT-1): last_g = _g _dq1 = tl.load(p_dq_inner, mask=mask, other=0) _dq2 = tl.load(p_dq_inter, mask=mask, other=0) _dq2 *= tl.math.exp2(_g) _dq = _dq1 + _dq2 tl.store(p_dq_inter, _dq, mask=mask) _dk1 = tl.load(p_dk_inner, mask=mask, other=0) _dk2 = tl.load(p_dk_inter, mask=mask, other=0) _dk2 *= tl.math.exp2(last_g - _g) _dk = _dk1 + _dk2 tl.store(p_dk_inter, _dk, mask=mask) _q = tl.load(p_q, mask=mask, other=0) _k = tl.load(p_k, mask=mask, other=0) _dg = _dq * _q - _dk * _k cum_grad_dg += _dg tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask) p_g -= DK p_k -= DK p_q -= DK p_dq_inner -= DK p_dk_inner -= DK p_dq_inter -= DK p_dk_inter -= DK p_dg -= DK