This commit is contained in:
josc146
2024-05-28 22:35:47 +08:00
parent 3488d22d22
commit f05a4acb04
138 changed files with 29047 additions and 334 deletions

View File

@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_linear_attn
from .chunk_fuse import fused_chunk_linear_attn
from .recurrent_fuse import fused_recurrent_linear_attn
__all__ = [
'chunk_linear_attn',
'fused_chunk_linear_attn',
'fused_recurrent_linear_attn'
]

View File

@@ -0,0 +1,359 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
@torch.jit.script
def normalize_output(q, k, o):
k = k.transpose(-2, -1)
k = k.cumsum(-1)
k = k.transpose(-2, -1)
z = (q * k).sum(-1, keepdim=True)
return o / (z + 1e-5)
@triton.jit
def chunk_linear_attn_fwd_kernel_h(
k,
v,
h,
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
for i_t in range(NT):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BK, BV]
b_h += tl.dot(b_k, b_v, allow_tf32=False)
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_linear_attn_fwd_kernel_o(
q,
k,
v,
h,
o,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_s = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
b_o += tl.dot(b_q, b_h, allow_tf32=False)
b_s += tl.dot(b_q, b_k, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_linear_attn_bwd_kernel_dh(
q,
do,
dh,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, V]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BK, BV]
b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
@triton.jit
def chunk_linear_attn_bwd_kernel_dqkv(
q,
k,
v,
h,
do,
dh,
dq,
dk,
dv,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
n_bh = tl.num_programs(2)
o_i = tl.arange(0, BT)
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BT]
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
# [BT, BV]
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# [BT, BT]
b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
# [BT, BK]
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
class ChunkLinearAttentionFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
@contiguous
def forward(ctx, q, k, v, scale, initial_state, output_final_state):
B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
ctx.scale = scale
final_state = None
if output_final_state:
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
h = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_linear_attn_fwd_kernel_h[grid](
k, v, h, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
num_warps=num_warps,
num_stages=num_stages
)
grid = (NV, NT, B * H)
o = torch.empty_like(v)
chunk_linear_attn_fwd_kernel_o[grid](
q, k, v, h, o,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
ctx.save_for_backward(q, k, v, h)
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_ht=None):
q, k, v, h = ctx.saved_tensors
B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(64, triton.next_power_of_2(K)), min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
scale = ctx.scale
dh = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_linear_attn_bwd_kernel_dh[grid](
q, do, dh,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
grid = (NK, NT, B * H)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = v.new_empty(NK, *v.shape)
num_stages = 1
num_warps = 4 if BK == 64 else 2
chunk_linear_attn_bwd_kernel_dqkv[grid](
q, k, v, h, do, dh, dq, dk, dv,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
dv = dv.sum(0)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
def chunk_linear_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float = -1,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
if scale == -1:
scale = q.shape[-1] ** -0.5
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = ChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
if normalize:
o = normalize_output(q * scale, k, o)
return o, final_state

View File

@@ -0,0 +1,326 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from packaging import version
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@torch.jit.script
def normalize_output(q, k, o):
k = k.transpose(-2, -1)
k = k.cumsum(-1)
k = k.transpose(-2, -1)
z = (q * k).sum(-1, keepdim=True)
return o / (z + 1e-5)
@triton.jit
def fused_chunk_linear_attn_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
# [BT, BT]
m_s = o_i[:, None] >= o_i[None, :]
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
# make block pointers
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
for i in range(0, tl.cdiv(T, BT)):
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_k.dtype)
# [BT, BT]
b_s = tl.dot(b_q, b_k, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
# [BT, BV]
b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
if CHECK and i == 0:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
else:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
p_q = tl.advance(p_q, (BT, 0))
p_k = tl.advance(p_k, (0, BT))
p_v = tl.advance(p_v, (BT, 0))
p_o = tl.advance(p_o, (BT, 0))
if STORE_FINAL_STATE:
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_chunk_linear_attn_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
# [BV, BK]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
for i in range(0, tl.cdiv(T, BT)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [DV, BT]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, DV]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT, BT]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
b_ds = tl.where(m_s, b_ds, 0)
# [BT, DK]
b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
# [DV, DK]
if CHECK and i == 0:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
else:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
b_dq *= scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
# sync threads
b_h = None
tl.debug_barrier()
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
m_s = o_i[:, None] <= o_i[None, :]
for i in range(1, tl.cdiv(T, BT) + 1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
# [DK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, DV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# b_dd = (b_do]).to(b_do.dtype)
# [BT, BT]
b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
# [BT, BT]
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
# [BT, DK]
b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
# [BT, DV]
b_dv = tl.dot(b_s, b_do, allow_tf32=False)
if CHECK and i == 1:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
else:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
tl.store(p_dk, (b_dk * scale).to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
class FusedChunkLinearAttentionFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, scale, initial_state, output_final_state):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
ctx.scale = scale
BT = 64
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 4
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
else:
final_state = None
# the bug still exists even for Triton 2.2 on H100 GPUs
# so we always enable initial checks
CHECK = True
if version.parse(triton.__version__) < version.parse('2.2.0'):
import warnings
warnings.warn(
"Triton<2.2.0 detected for running this kernel, "
"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
"that lead to significant precision loss. "
"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
)
CHECK = True
grid = (NV, NK, batch_size * n_heads)
fused_chunk_linear_attn_fwd_kernel[grid](
q, k, v, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
CHECK=CHECK,
num_warps=num_warps,
num_stages=num_stages
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, initial_state)
ctx.CHECK = CHECK
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_final_state=None):
q, k, v, initial_state = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = ctx.scale
BT = 64
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 4
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_linear_attn_bwd_kernel[grid](
q, k, v, do, dq, dk, dv, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
CHECK=ctx.CHECK,
num_warps=num_warps,
num_stages=num_stages
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
def fused_chunk_linear_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float = -1,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
if scale == -1:
scale = q.shape[-1] ** -0.5
o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
if normalize:
o = normalize_output(q * scale, k, o)
return o, final_state

View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
import torch
from einops import rearrange
def torch_chunk_linear_attn(q, k, v, chunk_size=64):
q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] **-0.5)
k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size)
v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size)
kv = k.transpose(-1, -2) @ v
kv = kv.cumsum(2)
kv = torch.cat([
torch.zeros_like(kv[:, :, :1]),
kv[:, :, :-1]
], dim=2)
inter = q @ kv
intra = ((q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)) @ v
o = inter + intra
return rearrange(o, 'b h n c d -> b h (n c) d')

View File

@@ -0,0 +1,284 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@torch.jit.script
def normalize_output(q, k, o):
k = k.transpose(-2, -1)
k = k.cumsum(-1)
k = k.transpose(-2, -1)
z = (q * k).sum(-1, keepdim=True)
return o / (z + 1e-5)
@triton.jit
def fused_recurrent_linear_attn_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
initial_state,
final_state, # final hidden state [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
mask_kv = mask_bk[None, :] & mask_bv[:, None]
h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
h += _k[None, :] * _v[:, None]
_o = h * _q[None, :]
_o = tl.sum(_o, axis=1)
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
p_q += DK
p_k += DK
p_o += DV
p_v += DV
if STORE_FINAL_STATE:
p_final_s = final_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_linear_attn_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
# initial hidden state initialization [B, H, D_head_K, D_head_V]
initial_state,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
mask_bk = i_k * BK + tl.arange(0, BK) < DK
mask_bv = i_v * BV + tl.arange(0, BV) < DV
h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
mask_kv = mask_bk[:, None] & mask_bv[None, :]
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[:, None]) * \
DV + (i_v * BV + tl.arange(0, BV)[None, :])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for i in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
h += _k[:, None] * _v[None, :]
_d_q = h * _do[None, :]
d_q = tl.sum(_d_q, axis=1) * scale
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
p_k += DK
p_do += DV
p_v += DV
p_dq += DK
# sync threads
tl.debug_barrier()
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
BK + tl.arange(0, BK) + (T - 1) * DK
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
BV + tl.arange(0, BV) + (T - 1) * DV
d_h = tl.zeros([BK, BV], dtype=tl.float32)
for _ in range(T):
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
d_h += _q[:, None] * _do[None, :]
d_k = tl.sum(d_h * _v[None, :], axis=1)
d_v = tl.sum(d_h * _k[:, None], axis=0)
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
p_do -= DV
p_q -= DK
p_k -= DK
p_v -= DV
p_dk -= DK
p_dv -= DV
class FusedRecurrentLinearAttentionFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, q, k, v, initial_state=None, output_final_state=False):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
else:
final_state = None
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_linear_attn_fwd_kernel[grid](
q, k, v, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, initial_state)
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, d_final_state=None):
q, k, v, initial_state = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_linear_attn_bwd_kernel[grid](
q, k, v, do, dq, dk, dv, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
return dq, dk, dv, None, None
def fused_recurrent_linear_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = FusedRecurrentLinearAttentionFunction.apply(
q, k, v, initial_state, output_final_state)
if normalize:
o = normalize_output(q, k, o)
return o, final_state