This commit is contained in:
josc146
2024-05-28 22:35:47 +08:00
parent 3488d22d22
commit f05a4acb04
138 changed files with 29047 additions and 334 deletions


@@ -0,0 +1,4 @@
- Delta Rule
The implementation of the delta rule described in https://arxiv.org/abs/2102.11174
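
A minimal per-token sketch of the recurrence (mirroring `delta_rule_recurrence` in `chunk_fuse.py` from this commit; the Triton kernels here implement fused recurrent and chunkwise variants of the same update, so the helper below is illustrative only):

```python
import torch

def delta_rule_step(S, q_t, k_t, v_t, beta_t):
    """One delta-rule step for a single head.
    S: [d_k, d_v] state; q_t, k_t: [d_k]; v_t: [d_v]; beta_t: scalar gate in (0, 1)."""
    v_t = beta_t * (v_t - k_t @ S)   # "delta": remove what the state already predicts for k_t
    S = S + torch.outer(k_t, v_t)    # rank-1 state update
    o_t = q_t @ S                    # read-out
    return o_t, S
```

(The kernels additionally scale `q` by `d_k ** -0.5` and process the sequence in chunks rather than token by token.)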


@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
from .chunk_fuse import fused_chunk_delta_rule
from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule
from .chunk import chunk_delta_rule
__all__ = [
'fused_chunk_delta_rule',
'fused_recurrent_linear_attn_delta_rule',
'chunk_delta_rule'
]
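# Example usage (a sketch; assumes CUDA tensors with q/k/v of shape
# [batch, heads, seq_len, head_dim] and beta of shape [batch, heads, seq_len],
# mirroring the __main__ tests in chunk_fuse.py, which also L2-normalize k):
#
#   q, k, v = (torch.randn(2, 4, 1024, 64, device='cuda') for _ in range(3))
#   beta = torch.rand(2, 4, 1024, device='cuda').sigmoid()
#   o, final_state = fused_chunk_delta_rule(q, k, v, beta, BT=32)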


@@ -0,0 +1,544 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
import torch
import triton
import triton.language as tl
from fla.ops.utils import contiguous
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.ops.delta_rule.wy_fast import fwd_recompute_w_u, fwd_prepare_wy_repr, bwd_prepare_wy_repr
from fla.ops.delta_rule.chunk_fuse import fused_chunk_delta_rule_fwd, fused_chunk_delta_rule_bwd
# from fla.ops.delta_rule.utils import bwd_prepare_wy_repr
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_dv_kernel(
q,
k,
do,
dv,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
scale,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_k.dtype)
b_A += tl.dot(b_k, b_q, allow_tf32=False)
b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_dv = tl.dot(b_A, b_do, allow_tf32=False)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
def fwd_prepare_dv(q, k, do, BT):
dv = torch.empty_like(do)
B, H, T, K, V = *k.shape, do.shape[-1]
NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
fwd_prepare_dv_kernel[(NT, B*H)](
q, k, do, dv,
k.stride(1), k.stride(2), k.stride(3),
do.stride(1), do.stride(2), do.stride(3),
T, K, V, K**-0.5, BT, BK, BV
)
return dv
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_fwd_kernel_h(
k,
v,
d,
v_new,
h,
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
for i_t in range(NT):
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32)
# the full [BK, BV] state block has to stay in SRAM, which imposes a severe memory burden; sub-chunking along BT (BC <= BT) alleviates it
for i_c in range(tl.cdiv(BT, BC)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BK]
b_d = tl.load(p_d, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)
# [BK, BV]
tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
b_h += b_h_cumsum
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_linear_attn_fwd_kernel_o(
q,
k,
v,
h,
o,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_s = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
b_o += tl.dot(b_q, b_h, allow_tf32=False)
b_s += tl.dot(b_q, b_k, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False))
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_bwd_kernel_dhu(
q,
k,
d,
do,
dh,
dv,
dv2,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_d = tl.load(p_d, boundary_check=(0, 1))
# [BT, V]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT, BT]
# b_s = tl.dot(b_k, b_q, allow_tf32=False)
# b_s = tl.where(m_s, b_s, 0)
# b_dv = tl.dot(b_s.to(b_do.dtype), b_do, allow_tf32=False) + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
p_dv2 = tl.make_block_ptr(dv2 + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# [BK, BV]
b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
b_dh += b_dh_tmp
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_bwd_kernel_dqkw(
q,
k,
v,
w,
h,
do,
dh,
dq,
dk,
dv,
dw,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
n_bh = tl.num_programs(2)
o_i = tl.arange(0, BT)
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dw = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BT]
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dw += tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
# [BT, BT]
b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
# [BT, BK]
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
B, H, T, K, V = *k.shape, u.shape[-1]
BK = triton.next_power_of_2(K)
assert BK <= 256, "current kernel does not support head dimension larger than 256."
BV = 16 if BK > 128 else 32
BV = 64 if BK <= 64 else BV
BC = 16 if BK > 128 else 32
BC = 64 if BK <= 64 else BC
BC = min(BT, BC)
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
h = k.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
v_new = torch.empty_like(u)
chunk_delta_rule_fwd_kernel_h[grid](
k, u, w, v_new, h, initial_state, final_state,
k.stride(1), k.stride(2), k.stride(3),
u.stride(1), u.stride(2), u.stride(3),
h.stride(1), h.stride(2),
H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None,
)
return h, v_new
def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
B, H, T, K, V = *q.shape, do.shape[-1]
BK = triton.next_power_of_2(K)
assert BK <= 256, "current kernel does not support head dimension larger than 256."
BV = 16 if BK > 128 else 32
BV = 64 if BK <= 64 else BV
BC = 16 if BK > 128 else 32
BC = 64 if BK <= 64 else BC
BC = min(BT, BC)
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
dh = q.new_empty(B, H, NT * K, V)
# dv_new = torch.empty_like(do)
grid = (NK, NV, B * H)
dv2 = torch.empty_like(dv)
chunk_delta_rule_bwd_kernel_dhu[grid](
q, k, w, do, dh, dv, dv2,
q.stride(1), q.stride(2), q.stride(3),
do.stride(1), do.stride(2), do.stride(3),
dh.stride(1), dh.stride(2),
K**-0.5,
H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
)
return dh, dv2
def chunk_fwd_o_fn(q, k, v_new, h, BT):
B, H, T, K, V = *q.shape, v_new.shape[-1]
BK = triton.next_power_of_2(K)
o = torch.empty_like(v_new)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
NV = triton.cdiv(V, BV)
NT = triton.cdiv(T, BT)
grid = (NV, NT, B * H)
chunk_linear_attn_fwd_kernel_o[grid](
q, k, v_new, h, o,
q.stride(1), q.stride(2), q.stride(3),
v_new.stride(1), v_new.stride(2), v_new.stride(3),
h.stride(1), h.stride(2),
scale=K**-0.5,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
)
return o
def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
B, H, T, K, V = *q.shape, v_new.shape[-1]
BK = triton.next_power_of_2(K)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
NV = triton.cdiv(V, BV)
NT = triton.cdiv(T, BT)
grid = (NV, NT, B * H)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dw = torch.empty_like(w)
chunk_delta_rule_bwd_kernel_dqkw[grid](
q, k, v_new, w, h, do, dh, dq, dk, du, dw,
q.stride(1), q.stride(2), q.stride(3),
v_new.stride(1), v_new.stride(2), v_new.stride(3),
dh.stride(1), dh.stride(2),
scale = K ** -0.5,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
)
return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
class ChunkDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
@contiguous
def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):
### obtain WY representation. u is actually the new v.
w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
### compute hidden states h
B, H, T, K, V = *q.shape, v.shape[-1]
final_state = None
if output_final_state:
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
## obtain output
o = chunk_fwd_o_fn(q, k, v_new, h, BT)
# save memory
if checkpoint_level == 1:
h, v_new = None, None
ctx.save_for_backward(q, k, v, beta, A, h, v_new, initial_state)
ctx.BT = BT
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_ht=None):
q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
scale = q.shape[-1] ** -0.5
BT = ctx.BT
w, u = fwd_recompute_w_u(k, v, beta, A, BT)
# checkpoint_level == 1: recompute h and v_new
if h is None:
h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
dv = fwd_prepare_dv(q, k, do, BT)
dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)
dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, dv, BT)
dk.add_(dk2)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
def chunk_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
BT: int,
initial_state: torch.Tensor = None,
output_final_state: bool = False
):
assert q.dtype == k.dtype == v.dtype
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
return o, final_state


@@ -0,0 +1,419 @@
# -*- coding: utf-8 -*-
from typing import Tuple
import torch
import triton
import triton.language as tl
from packaging import version
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
from fla.utils import contiguous
# on-the-fly computation without materializing hidden states into HBM
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8)
],
key=["BT", "BK"],
)
@triton.jit
def fused_chunk_delta_rule_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_K]
v, # value [B, H, L, D_head_V]
v_new,
d, # decay [B, H, L, D_head_K]
o, # output [B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
# [BT, BT]
m_s = o_i[:, None] >= o_i[None, :]
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
# make block pointers
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
for i in range(0, tl.cdiv(T, BT)):
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_d = tl.load(p_d, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_k.dtype)
# [BT, BT]
b_s = tl.dot(b_q, b_k, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
# [BT, BV]
b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)
b_v = b_v - b_v_prime
tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False)
if CHECK and i == 0:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
else:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
p_q = tl.advance(p_q, (BT, 0))
p_k = tl.advance(p_k, (0, BT))
p_v = tl.advance(p_v, (BT, 0))
p_v_new = tl.advance(p_v_new, (BT, 0))
p_o = tl.advance(p_o, (BT, 0))
p_d = tl.advance(p_d, (BT, 0))
if STORE_FINAL_STATE:
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
# Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fused_chunk_delta_rule_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of splits in the V dimension. NK: number of splits in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_K]
v, # value [B, H, L, D_head_V]
d, # decay [B, H, L, D_head_K]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
dd, # gradient of decay [NV, B, H, L, D_head_K]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
# first reverse
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
m_s = o_i[:, None] <= o_i[None, :]
for i in range(1, tl.cdiv(T, BT) + 1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
# [DK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, DV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT, BT]
b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
# [BT, BT]
b_s = tl.dot(b_k, b_q, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
# [BT, DK]
b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
# [BT, DV]
b_dv = tl.dot(b_s, b_do, allow_tf32=False)
b_d = tl.load(p_d, boundary_check=(0, 1))
if CHECK and i == 1:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
else:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# sync threads
b_h = None
tl.debug_barrier()
m_s = o_i[:, None] >= o_i[None, :]
# [BV, BK]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
NT = tl.cdiv(T, BT)
for i in range(0, NT):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [DV, BT]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, DV]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT, BT]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
b_ds = tl.where(m_s, b_ds, 0)
# [BT, DK]
b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
# [DV, DK]
if CHECK and i == 0:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
else:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
b_dq *= scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
if i < (NT - 1):
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0))
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
((i+1) * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1))
def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BT = BT
# ctx.BT = BT
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
assert NK == 1, 'NK should be 1'
o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
else:
final_state = None
CHECK = True
# if version.parse(triton.__version__) < version.parse('2.2.0'):
# import warnings
# warnings.warn(
# "Triton<2.2.0 detected for running this kernel, "
# "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
# "that lead to significant precision loss. "
# "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
# "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
# )
# CHECK = True
grid = (NV, NK, batch_size * n_heads)
v_new = torch.empty_like(v)
fused_chunk_delta_rule_fwd_kernel[grid](
q, k, v, v_new, d, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
CHECK=CHECK,
)
return o, v_new, CHECK, final_state
def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
assert NK == 1
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_delta_rule_bwd_kernel[grid](
q, k, v, d, do, dq, dk, dv, dd, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
CHECK=CHECK,
# num_warps=num_warps,
# num_stages=num_stages
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
dd = dd.sum(0)
dd[:, :, 0:BT] = 0
return dq, dk, dv, dd
class FusedChunkDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0):
# checkpoint_level=1 recomputes ``fwd_prepare_wy_repr`` in the backward pass to save memory.
assert checkpoint_level in [0, 1]
k_origin = k
# k = _l2_norm_fwd(k_origin)
k = k
d, v_new = fwd_prepare_wy_repr(k, v, beta, BT)
o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state)
if checkpoint_level == 1:
d, v_new = None, None
ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state)
ctx.CHECK = CHECK
ctx.chunk_size = BT
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_final_state=None):
q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors
chunk_size = ctx.chunk_size
k = k_origin
# k = _l2_norm_fwd(k_origin)
if d is None:
d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state)
dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size)
dk.add_(dk2)
# dk = _l2_norm_bwd(k_origin, dk)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
def fused_chunk_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
BT: int,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
return o, final_state
def delta_rule_recurrence(q, k, v, beta):
b, h, l, d_k = q.shape
d_v = v.shape[-1]
o = torch.zeros_like(v)
S = torch.zeros(b, h, d_k, d_v).to(v)
q = q * (d_k ** -0.5)
k = torch.nn.functional.normalize(k, p=2, dim=-1)
for i in range(l):
_k = k[:, :, i]
_q = q[:, :, i]
_v = v[:, :, i].clone()
beta_i = beta[:, :, i]
_v = _v - (S.clone() * _k[..., None]).sum(-2)
_v = _v * beta_i[..., None]
S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
return o
if __name__ == "__main__":
import torch.nn.functional as F
seq_len = 128
b = 2
h = 4
q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1)
beta = torch.rand(b, h, seq_len).sigmoid()
q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta))
do = torch.rand_like(v)
o2 = delta_rule_recurrence(q, k, v.clone(), beta)
o2.backward(do, retain_graph=True)
q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad
q.grad = k.grad = v.grad = beta.grad = None
o, _ = fused_chunk_delta_rule(q, k, v, beta, 32)
o.backward(do, retain_graph=True)
q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad
q.grad = k.grad = v.grad = beta.grad = None
print((o - o2).abs().max())
print((q_grad - q_grad2).abs().max())
print((k_grad - k_grad2).abs().max())
print((v_grad - v_grad2).abs().max())
print((beta_grad - beta_grad2).abs().max())


@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
import torch
from einops import rearrange
def delta_rule_recurrence(q, k, v, beta):
b, h, l, d_k = q.shape
d_v = v.shape[-1]
o = torch.zeros_like(v)
S = torch.zeros(b, h, d_k, d_v).to(v)
q = q * (d_k ** -0.5)
for i in range(l):
_k = k[:, :, i]
_q = q[:, :, i]
_v = v[:, :, i].clone()
beta_i = beta[:, :, i]
_v = _v - (S.clone() * _k[..., None]).sum(-2)
_v = _v * beta_i[..., None]
S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
return o
def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
b, h, l, d_k = q.shape
d_v = v.shape[-1]
q = q * (d_k ** -0.5)
v = v * beta[..., None]
k_beta = k * beta[..., None]
assert l % chunk_size == 0
# note that diagonal is masked.
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
for i in range(1, chunk_size):
attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
# u
k_cumsum = attn @ v
# w
k_cumdecay = attn @ k_beta
v = k_cumsum
S = k.new_zeros(b, h, d_k, d_v)
o = torch.zeros_like(v)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
for i in range(0, l // chunk_size):
q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i]
attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
v_prime = k_cumdecay[:, :, i] @ S
v_new = v_i - v_prime
o_inter = q_i @ S
o[:, :, i] = o_inter + attn @ v_new
# chunk state update
S = S + k_i.transpose(-1, -2) @ v_new
return rearrange(o, 'b h n c d -> b h (n c) d')
if __name__ == '__main__':
B = 2
H = 4
L = 256
DK = 128
DV = 128
q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
k = (torch.randn(B, H, L, DK)).cuda()
k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
o = delta_rule_recurrence(q, k, v, beta)
do = torch.randn(B, H, L, DV).cuda()
o.backward(do, retain_graph=True)
q_grad, q.grad = q.grad, None
k_grad, k.grad = k.grad, None
v_grad, v.grad = v.grad, None
beta_grad, beta.grad = beta.grad, None
o2 = delta_rule_chunkwise(q, k, v, beta)
o2.backward(do)
assert torch.allclose(o, o2, atol=1e-4), breakpoint()
assert torch.allclose(q.grad, q_grad, atol=1e-4), breakpoint()
assert torch.allclose(k.grad, k_grad, atol=1e-4), breakpoint()
assert torch.allclose(v.grad, v_grad, atol=1e-4), breakpoint()
assert torch.allclose(beta.grad, beta_grad, atol=1e-4), breakpoint()
print("All passed!")


@@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
# on-the-fly computation without materializing hidden states into HBM
@triton.jit
def fused_recurrent_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_K]
v, # value [B, H, L, D_head_V].
beta, # beta [B, H, L]
o, # output [B, H, L, D_head_V]
initial_state,
final_state, # final hidden state [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_beta = beta + i_bh * T
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
mask_kv = mask_bk[None, :] & mask_bv[:, None]
h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_v_minus = tl.sum(h * _k[None, :], axis=1)
_v -= _v_minus
_beta = tl.load(p_beta).to(tl.float32)
# in-place overwrite
tl.store(p_v, _v.to(p_v.dtype.element_ty), mask=mask_bv)
_v *= _beta
h += _k[None, :] * _v[:, None]
_o = h * _q[None, :]
_o = tl.sum(_o, axis=1)
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
p_q += DK
p_k += DK
p_o += DV
p_v += DV
p_beta += 1
if STORE_FINAL_STATE:
p_final_s = final_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
# Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of splits in the V dimension. NK: number of splits in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_K]
v, # value [B, H, L, D_head_V]
beta, # beta [B, H, L]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
dbeta, # gradient of beta [B, H, L]
# initial hidden state [B, H, D_head_K, D_head_V]
initial_state,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
mask_bk = i_k * BK + tl.arange(0, BK) < DK
mask_bv = i_v * BV + tl.arange(0, BV) < DV
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_beta = beta + i_bh * T + T - 1
p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
BK + tl.arange(0, BK) + (T - 1) * DK
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
BV + tl.arange(0, BV) + (T - 1) * DV
d_h = tl.zeros([BK, BV], dtype=tl.float32)
for _ in range(T):
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_beta = tl.load(p_beta).to(tl.float32)
d_h += _q[:, None] * _do[None, :]
d_k = tl.sum(d_h * _v[None, :] * _beta, axis=1)
d_v = tl.sum(d_h * _k[:, None], axis=0)
d_beta = tl.sum(d_v * _v)
d_v = d_v * _beta
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))
d_h -= _k[:, None] * d_v[None, :]
p_do -= DV
p_q -= DK
p_k -= DK
p_v -= DV
p_dk -= DK
p_dv -= DV
p_dbeta -= 1
p_beta -= 1
tl.debug_barrier()
h = tl.zeros([BK, BV], dtype=tl.float32)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_beta = beta + i_bh * T
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + DV
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + DK
if USE_INITIAL_STATE:
mask_kv = mask_bk[:, None] & mask_bv[None, :]
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[:, None]) * \
DV + (i_v * BV + tl.arange(0, BV)[None, :])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for i in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
_beta = tl.load(p_beta).to(tl.float32)
_v *= _beta
h += _k[:, None] * _v[None, :]
_d_q = h * _do[None, :]
d_q = tl.sum(_d_q, axis=1) * scale
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
if i < T - 1:
d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32)
d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32)
d_k -= tl.sum(d_v[None, :] * h, axis=1)
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
p_k += DK
p_do += DV
p_v += DV
p_dk += DK
p_dv += DV
p_dq += DK
p_beta += 1
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, q, k, v, beta, initial_state=None, output_final_state=False):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 8)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
assert NK == 1, "NK > 1 is not supported yet"
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
else:
final_state = None
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_fwd_kernel[grid](
q, k, v, beta, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, beta, initial_state)
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, d_final_state=None):
q, k, v, beta, initial_state = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 1
num_warps = 2
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
dbeta = q.new_empty(NV, batch_size, n_heads, seq_len)
fused_recurrent_bwd_kernel[grid](
q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
dbeta = dbeta.sum(0)
return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None
def fused_recurrent_linear_attn_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
if beta is None:
beta = torch.ones_like(q[..., 0])
o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, initial_state, output_final_state)
return o, final_state


@@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
from einops import rearrange
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
from fla.ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2
# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
# o: cumprod
# o2: cumprodsum
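# Roughly (see `delta_rule_chunkwise` in naive.py for a reference construction):
# within a chunk, the product of the per-step updates (I - beta_i * k_i k_i^T)
# is folded into a single triangular transform T built from tril(K_beta @ K^T, -1),
# so that o = T @ (beta * K) and o2 = T @ (beta * V) come out of two matmuls
# instead of BT sequential rank-1 updates.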
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_wy_repr_kernel(
k,
v,
beta,
o,
o2,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
mask_bt = (tl.arange(0, BT) + i_t * BT) < T
mask_bk = tl.arange(0, BK) < K
mask_bv = tl.arange(0, BV) < V
mask_bk = mask_bk[None, :] & mask_bt[:, None]
mask_bv = mask_bv[None, :] & mask_bt[:, None]
# [BT, BK]
b_k = tl.load(p_k, mask=mask_bk, other=0)
# [BT,]
b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)
# [BT, BV]
b_v = tl.load(p_v, mask=mask_bv, other=0)
b_v = (b_v * b_beta[:, None]).to(b_v.dtype)
# [BT, BK]
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
# [BT, BT]
b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
for i in range(BT):
mask = tl.arange(0, BT) == i
b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
b_A = tl.where(mask[:, None], b_a, b_A)
b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
b_A = b_A.to(b_k.dtype)
b_w = tl.dot(b_A, b_kb, allow_tf32=False)
b_u = tl.dot(b_A, b_v, allow_tf32=False)
p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)
p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def bwd_prepare_wy_repr_kernel(
k, v, beta,
o, o2, do, do2,
dk, dv, dbeta,
NT, K, V, T,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
mask_bt = (tl.arange(0, BT) + i_t * BT) < T
mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]
mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]
b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)
b_beta = b_beta.to(tl.float32)
A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]
A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)
b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)
b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)
dA = tl.zeros([BT, BT], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
for i in range(BT-1, -1, -1):
mask = tl.arange(0, BT) == i
attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)
do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)
dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)
b_do = b_do - attn[:, None] * do_[None, :]
b_dv = b_dv - attn[:, None] * dv_[None, :]
tl.debug_barrier()
p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
b_v = tl.load(p_v, mask=mask_bv)
b_dk += b_do * b_beta[:, None]
b_dbeta = tl.sum(b_do * b_k, axis=1)
b_dbeta += tl.sum(b_dv * b_v, axis=1)
b_v = None
p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
b_o = tl.load(p_o, mask=mask_bk)
b_o2 = tl.load(p_o2, mask=mask_bv)
dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)
dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),
allow_tf32=False)
dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)
b_dv *= b_beta[:, None]
p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)
dA = dA * b_beta[:, None]
b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)
b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)
p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)
tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)
def fwd_prepare_wy_repr(k, v, beta, chunk_size):
B, H, T, K, V = *k.shape, v.shape[-1]
v_new = torch.empty_like(v)
o_cumdecay = torch.empty_like(k)
BT = chunk_size
NT = triton.cdiv(T, BT)
BK = triton.next_power_of_2(K)
BV = triton.next_power_of_2(V)
fwd_prepare_wy_repr_kernel[(NT, B*H)](
k, v, beta, o_cumdecay, v_new,
T, K, V, BT, BK, BV
)
return o_cumdecay, v_new
def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):
b, h, l, d_k = do.shape
d_v = v.shape[-1]
BK = triton.next_power_of_2(d_k)
BV = triton.next_power_of_2(d_v)
c = chunk_size
BK = d_k
NT = triton.cdiv(l, c)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
dbeta = torch.zeros_like(beta)
bwd_prepare_wy_repr_kernel[(NT, b*h)](
k, v, beta,
o_cumdecay, v_new, do, do2,
dk, dv, dbeta,
NT, d_k, d_v, l, chunk_size, BK, BV
)
return dk, dv, dbeta
class WYRepresentationPrepration(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, k, v, beta, chunk_size):
o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
ctx.chunk_size = chunk_size
ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new)
return o_cumdecay, v_new
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, do2):
k, v, beta, o_cumdecay, v_new = ctx.saved_tensors
dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size)
return dk, dv, dbeta, None
prepare_wy_repr = WYRepresentationPrepration.apply
def naive(k, v, beta, chunk_size):
l_org = k.shape[2]
l_new = triton.next_power_of_2(l_org)
# pad k, v, beta
k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
# k = torch.nn.functional.normalize(k, dim=-1, p=2)
beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
k_beta = k * beta[..., None]
v = v * beta[..., None]
attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
attn = attn * beta[..., None]
x = attn @ v
o = torch.zeros_like(k)
o2 = torch.zeros_like(v)
o[..., 0, :] = k_beta[..., 0, :].clone()
o2[..., 0, :] = x[..., 0, :].clone()
for i in range(1, chunk_size):
o_i = (o[..., :i, :]).clone()
o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
o2_i = (o2[..., :i, :]).clone()
o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
if __name__ == "__main__":
torch.set_default_dtype(torch.bfloat16)
seq_len = 2048
b = 4
h = 8
k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2)
v = torch.randn(b, h, seq_len, 256)
beta = torch.rand(b, h, seq_len).sigmoid()
require_grad = True
k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
do = torch.rand_like(k)
do2 = torch.rand_like(v)
print("Start warmup.")
o1, o2 = prepare_wy_repr(k, v, beta, 32)
# (o1 * do + o2 * do2).sum().backward()
o3, o4 = prepare_wy_repr2(k, v, beta, 32)
# (o1 * do + o2 * do2).sum().backward()
print((o1 - o3).abs().max())
print((o2 - o4).abs().max())
for i in range(30):
o1, o2 = prepare_wy_repr(k, v, beta, 32)
(o1 * do + o2 * do2).sum().backward()
o1, o2 = prepare_wy_repr2(k, v, beta, 32)
(o1 * do + o2 * do2).sum().backward()
print("Done warmup.")
import time
torch.cuda.synchronize()
start = time.time()
for i in range(200):
o1, o2 = prepare_wy_repr(k, v, beta, 64)
(o1 * do + o2 * do2).sum().backward()
torch.cuda.synchronize()
print(time.time() - start)
torch.cuda.synchronize()
start = time.time()
for i in range(200):
o1, o2 = prepare_wy_repr2(k, v, beta, 64)
(o1 * do + o2 * do2).sum().backward()
torch.cuda.synchronize()
print(time.time() - start)


@@ -0,0 +1,401 @@
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
from einops import rearrange
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
# o: cumprod
# o2: cumprodsum
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_wy_repr_kernel(
k,
v,
beta,
w,
u,
A,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
for i in range(1, BT):
mask = tl.arange(0, BT) == i
b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
b_A = tl.where(mask[:, None], b_a, b_A)
b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))
b_A = b_A.to(k.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
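    # w = T @ (beta * K), computed block by block over the key dimension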
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
b_w = tl.dot(b_A, b_kb, allow_tf32=False)
p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
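# Identical to the second half of fwd_prepare_wy_repr_kernel: rebuild w = A @ (beta * K) and
# u = A @ (beta * V) from the cached matrix A, so callers can avoid keeping w/u around
# between the forward and backward passes.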
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_recompute_w_u_kernel(
k,
v,
beta,
w,
u,
A,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
b_w = tl.dot(b_A, b_kb, allow_tf32=False)
p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
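# Backward of the WY construction: given the cached matrix A and the incoming gradients
# dw/du, produce dk, dv and dbeta. The first two loops handle the paths through
# A @ (beta * v) and A @ (beta * k); the remaining code backpropagates through A itself.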
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def bwd_prepare_wy_repr_kernel(
k, v, beta, A,
dw, du,
dk, dv, dbeta,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
b_dbeta = tl.zeros([BT], dtype=tl.float32)
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
b_du = tl.load(p_du, boundary_check=(0, 1))
b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)
b_dv = b_dv_beta * b_beta[:, None]
b_dbeta += tl.sum(b_dv_beta * b_v, 1)
# store
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
tl.debug_barrier()
b_A2 = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
b_dw = tl.load(p_dw, boundary_check=(0, 1))
b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
b_A2 += tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)
b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
b_dk = b_dk_beta * b_beta[:, None]
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
# store
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
b_A -= (tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :])
b_A2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_A2, 0)
b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
tl.debug_barrier()
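    # backpropagate through the row-by-row forward substitution of the forward kernel, in reverse row order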
for i in range(BT-1, 0, -1):
mask = tl.arange(0, BT) == i
b_da = tl.sum(tl.where(mask[:, None], b_dA, 0), 0)
b_a = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
b_da2 = b_da + tl.sum(b_da[None, :] * b_A, 1)
b_dA = tl.where(mask[:, None], b_da2, b_dA)
b_dA += b_da[None, :] * b_a[:, None]
b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
tl.debug_barrier()
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dk = tl.load(p_dk, boundary_check=(0, 1))
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
b_dk += b_dk_beta * b_beta[:, None]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
def fwd_prepare_wy_repr(k, v, beta, BT):
B, H, T, K, V = *k.shape, v.shape[-1]
u = torch.empty_like(v)
w = torch.empty_like(k)
NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)
fwd_prepare_wy_repr_kernel[(NT, B*H)](
k, v, beta, w, u, A,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
T, K, V, BT, BK, BV
)
return w, u, A
def fwd_recompute_w_u(k, v, beta, A, BT):
B, H, T, K, V = *k.shape, v.shape[-1]
u = torch.empty_like(v)
w = torch.empty_like(k)
NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
fwd_recompute_w_u_kernel[(NT, B*H)](
k, v, beta, w, u, A,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
T, K, V, BT, BK, BV
)
return w, u
def bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):
B, H, T, K, V = *k.shape, v.shape[-1]
    NT = triton.cdiv(T, BT)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
dk = torch.empty_like(k)
dv = torch.empty_like(v).contiguous()
dbeta = torch.zeros_like(beta)
bwd_prepare_wy_repr_kernel[(NT, B*H)](
k, v, beta, A,
dw, du,
dk, dv, dbeta,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
T, K, V, BT, BK, BV
)
return dk, dv, dbeta
class WYRepresentationPrepration(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, k, v, beta, chunk_size):
ctx.BT = chunk_size
w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT)
ctx.save_for_backward(k, v, beta, A)
return w, u
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, dw, du):
k, v, beta, A = ctx.saved_tensors
BT = ctx.BT
dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT)
return dk, dv, dbeta, None
prepare_wy_repr = WYRepresentationPrepration.apply
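# expected shapes (as used in the test below): k [B, H, T, K], v [B, H, T, V], beta [B, H, T];
# returns (w, u) with the same shapes as k and v:
#   w, u = prepare_wy_repr(k, v, beta, chunk_size)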
def naive(k, v, beta, chunk_size):
l_org = k.shape[2]
l_new = triton.next_power_of_2(l_org)
# pad k, v, beta
k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
# k = torch.nn.functional.normalize(k, dim=-1, p=2)
beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
k_beta = k * beta[..., None]
v = v * beta[..., None]
attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
attn = attn * beta[..., None]
x = attn @ v
o = torch.zeros_like(k)
o2 = torch.zeros_like(v)
o[..., 0, :] = k_beta[..., 0, :].clone()
o2[..., 0, :] = x[..., 0, :].clone()
for i in range(1, chunk_size):
o_i = (o[..., :i, :]).clone()
o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
o2_i = (o2[..., :i, :]).clone()
o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
if __name__ == "__main__":
torch.set_default_dtype(torch.float32)
seq_len = 1024
b = 4
h = 4
k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)
v = torch.randn(b, h, seq_len, 128)
beta = torch.rand(b, h, seq_len).sigmoid()
# beta = torch.ones(b, h, seq_len)
require_grad = True
k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
do = torch.rand_like(k)
do2 = torch.rand_like(v)
o1, o2 = naive(k.clone(), v.clone(), beta.clone(), 64)
if require_grad:
o1.backward(do, retain_graph=True)
o2.backward(do2, retain_graph=True)
k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad
k.grad = v.grad = beta.grad = None
    o3, o4 = prepare_wy_repr(k.clone(), v.clone(), beta.clone(), 64)
print((o1-o3).abs().max())
print((o2-o4).abs().max())
if require_grad:
o3.backward(do, retain_graph=True)
o4.backward(do2, retain_graph=True)
k_grad, v_grad, beta_grad = k.grad, v.grad, beta.grad
print((k_grad2-k_grad).abs().max())
print((v_grad2-v_grad).abs().max())
print((beta_grad2-beta_grad).abs().max())
breakpoint()