This commit is contained in:
544
finetune/lora/v6/fla/ops/delta_rule/chunk.py
vendored
Normal file
544
finetune/lora/v6/fla/ops/delta_rule/chunk.py
vendored
Normal file
@@ -0,0 +1,544 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from fla.ops.utils import contiguous
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from fla.ops.delta_rule.wy_fast import fwd_recompute_w_u, fwd_prepare_wy_repr, bwd_prepare_wy_repr
|
||||
from fla.ops.delta_rule.chunk_fuse import fused_chunk_delta_rule_fwd, fused_chunk_delta_rule_bwd
|
||||
# from fla.ops.delta_rule.utils import bwd_prepare_wy_repr
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["BT", "BK", "BV"],
|
||||
)
|
||||
@triton.jit
|
||||
def fwd_prepare_dv_kernel(
|
||||
q,
|
||||
k,
|
||||
do,
|
||||
dv,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
T,
|
||||
K,
|
||||
V,
|
||||
scale,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_k.dtype)
|
||||
b_A += tl.dot(b_k, b_q, allow_tf32=False)
|
||||
|
||||
b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A , 0).to(do.dtype.element_ty)
|
||||
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_dv = tl.dot(b_A, b_do, allow_tf32=False)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def fwd_prepare_dv(q, k, do, BT):
|
||||
dv = torch.empty_like(do)
|
||||
B, H, T, K, V = *k.shape, do.shape[-1]
|
||||
NT = triton.cdiv(T, BT)
|
||||
BK = min(triton.next_power_of_2(K), 64)
|
||||
BV = min(triton.next_power_of_2(V), 64)
|
||||
fwd_prepare_dv_kernel[(NT, B*H)](
|
||||
q, k, do, dv,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
do.stride(1), do.stride(2), do.stride(3),
|
||||
T, K, V, K**-0.5, BT, BK, BV
|
||||
)
|
||||
return dv
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["BT", "BK", "BV"],
|
||||
)
|
||||
@triton.jit
|
||||
def chunk_delta_rule_fwd_kernel_h(
|
||||
k,
|
||||
v,
|
||||
d,
|
||||
v_new,
|
||||
h,
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
# [BK, BV]
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
for i_t in range(NT):
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
# since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
|
||||
for i_c in range(tl.cdiv(BT, BC)):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
|
||||
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
|
||||
p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BK]
|
||||
b_d = tl.load(p_d, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)
|
||||
# [BK, BV]
|
||||
tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
|
||||
b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
|
||||
b_h += b_h_cumsum
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["BT", "BK", "BV"],
|
||||
)
|
||||
@triton.jit
|
||||
def chunk_linear_attn_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
o,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_s = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
b_o += tl.dot(b_q, b_h, allow_tf32=False)
|
||||
b_s += tl.dot(b_q, b_k, allow_tf32=False)
|
||||
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False))
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["BT", "BK", "BV"],
|
||||
)
|
||||
@triton.jit
|
||||
def chunk_delta_rule_bwd_kernel_dhu(
|
||||
q,
|
||||
k,
|
||||
d,
|
||||
do,
|
||||
dh,
|
||||
dv,
|
||||
dv2,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
# [BK, BV]
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for i_t in range(NT - 1, -1, -1):
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
|
||||
b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
|
||||
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
|
||||
# [BK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BT, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_d = tl.load(p_d, boundary_check=(0, 1))
|
||||
# [BT, V]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BT]
|
||||
# b_s = tl.dot(b_k, b_q, allow_tf32=False)
|
||||
# b_s = tl.where(m_s, b_s, 0)
|
||||
# b_dv = tl.dot(b_s.to(b_do.dtype), b_do, allow_tf32=False) + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
|
||||
|
||||
b_dv = tl.load(p_dv, boundary_check=(0, 1))
|
||||
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
|
||||
p_dv2 = tl.make_block_ptr(dv2 + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
|
||||
tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
|
||||
b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
|
||||
b_dh += b_dh_tmp
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["BT", "BK", "BV"],
|
||||
)
|
||||
@triton.jit
|
||||
def chunk_delta_rule_bwd_kernel_dqkw(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
h,
|
||||
do,
|
||||
dh,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
dw,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
n_bh = tl.num_programs(2)
|
||||
o_i = tl.arange(0, BT)
|
||||
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
|
||||
b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)
|
||||
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dw = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BV, BK]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh = tl.load(p_dh, boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
|
||||
|
||||
b_dv = tl.load(p_dv, boundary_check=(0, 1))
|
||||
b_dw += tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
|
||||
|
||||
# [BT, BT]
|
||||
b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
|
||||
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
|
||||
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
|
||||
def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
|
||||
B, H, T, K, V = *k.shape, u.shape[-1]
|
||||
|
||||
BK = triton.next_power_of_2(K)
|
||||
assert BK <= 256, "current kernel does not support head dimension larger than 256."
|
||||
BV = 16 if BK > 128 else 32
|
||||
BV = 64 if BK <= 64 else BV
|
||||
BC = 16 if BK > 128 else 32
|
||||
BC = 64 if BK <= 64 else BC
|
||||
BC = min(BT, BC)
|
||||
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
|
||||
|
||||
h = k.new_empty(B, H, NT * K, V)
|
||||
grid = (NK, NV, B * H)
|
||||
v_new = torch.empty_like(u)
|
||||
chunk_delta_rule_fwd_kernel_h[grid](
|
||||
k, u, w, v_new, h, initial_state, final_state,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
u.stride(1), u.stride(2), u.stride(3),
|
||||
h.stride(1), h.stride(2),
|
||||
H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=final_state is not None,
|
||||
)
|
||||
return h, v_new
|
||||
|
||||
|
||||
def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
|
||||
B, H, T, K, V = *q.shape, do.shape[-1]
|
||||
|
||||
BK = triton.next_power_of_2(K)
|
||||
assert BK <= 256, "current kernel does not support head dimension being larger than 256."
|
||||
BV = 16 if BK > 128 else 32
|
||||
BV = 64 if BK <= 64 else BV
|
||||
BC = 16 if BK > 128 else 32
|
||||
BC = 64 if BK <= 64 else BC
|
||||
BC = min(BT, BC)
|
||||
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
|
||||
|
||||
dh = q.new_empty(B, H, NT * K, V)
|
||||
# dv_new = torch.empty_like(do)
|
||||
grid = (NK, NV, B * H)
|
||||
dv2 = torch.empty_like(dv)
|
||||
chunk_delta_rule_bwd_kernel_dhu[grid](
|
||||
q, k, w, do, dh, dv, dv2,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
do.stride(1), do.stride(2), do.stride(3),
|
||||
dh.stride(1), dh.stride(2),
|
||||
K**-0.5,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
|
||||
)
|
||||
return dh, dv2
|
||||
|
||||
|
||||
def chunk_fwd_o_fn(q, k, v_new, h, BT):
|
||||
B, H, T, K, V = *q.shape, v_new.shape[-1]
|
||||
|
||||
BK = triton.next_power_of_2(K)
|
||||
o = torch.empty_like(v_new)
|
||||
BK = min(triton.next_power_of_2(K), 64)
|
||||
BV = min(triton.next_power_of_2(K), 64)
|
||||
NV = triton.cdiv(V, BV)
|
||||
NT = triton.cdiv(T, BT)
|
||||
grid = (NV, NT, B * H)
|
||||
chunk_linear_attn_fwd_kernel_o[grid](
|
||||
q, k, v_new, h, o,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v_new.stride(1), v_new.stride(2), v_new.stride(3),
|
||||
h.stride(1), h.stride(2),
|
||||
scale=K**-0.5,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
)
|
||||
return o
|
||||
|
||||
|
||||
|
||||
def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
|
||||
B, H, T, K, V = *q.shape, v_new.shape[-1]
|
||||
|
||||
BK = triton.next_power_of_2(K)
|
||||
BK = min(triton.next_power_of_2(K), 64)
|
||||
BV = min(triton.next_power_of_2(V), 64)
|
||||
NV = triton.cdiv(V, BV)
|
||||
NT = triton.cdiv(T, BT)
|
||||
grid = (NV, NT, B * H)
|
||||
dq = torch.empty_like(q)
|
||||
dk = torch.empty_like(k)
|
||||
dw = torch.empty_like(w)
|
||||
chunk_delta_rule_bwd_kernel_dqkw[grid](
|
||||
q, k, v_new, w, h, do, dh, dq, dk, du, dw,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v_new.stride(1), v_new.stride(2), v_new.stride(3),
|
||||
dh.stride(1), dh.stride(2),
|
||||
scale = K ** -0.5,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
)
|
||||
return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
|
||||
|
||||
|
||||
class ChunkDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
@contiguous
|
||||
def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):
|
||||
### obtain WY representation. u is actually the new v.
|
||||
w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
|
||||
# ### forward_h
|
||||
final_state = None
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
|
||||
h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
|
||||
## obtain output
|
||||
o = chunk_fwd_o_fn(q, k, v_new, h, BT)
|
||||
# save memory
|
||||
if checkpoint_level == 1:
|
||||
h, v_new = None, None
|
||||
ctx.save_for_backward(q, k, v, beta, A, h, v_new, initial_state)
|
||||
ctx.BT = BT
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
@contiguous
|
||||
def backward(ctx, do, d_ht=None):
|
||||
q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
|
||||
scale = q.shape[-1] ** -0.5
|
||||
BT = ctx.BT
|
||||
w, u = fwd_recompute_w_u(k, v, beta, A, BT)
|
||||
# checkpont_level=1, recomputation.
|
||||
if h is None:
|
||||
h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
|
||||
dv = fwd_prepare_dv(q, k, do, BT)
|
||||
dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)
|
||||
dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
|
||||
dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, dv, BT)
|
||||
dk.add_(dk2)
|
||||
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
|
||||
|
||||
def chunk_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
BT: int,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False
|
||||
):
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
|
||||
return o, final_state
|
||||
Reference in New Issue
Block a user