282 lines
9.6 KiB
Python
282 lines
9.6 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
# Copyright (c) 2023, Yu Zhang, Songlin Yang
|
||
|
|
||
|
from typing import Tuple
|
||
|
|
||
|
import torch
|
||
|
import triton
|
||
|
import triton.language as tl
|
||
|
|
||
|
from fla.utils import contiguous
|
||
|
|
||
|
# on-the-fly computation without materializing hidden statets into HBMs
|
||
|
|
||
|
|
||
|
@triton.jit
|
||
|
def fused_recurrent_retention_fwd_kernel(
|
||
|
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||
|
q, # query [B, H, L, D_head_K]
|
||
|
k, # key [B, H, L, D_head_V]
|
||
|
v, # value [B, H, L, D_head_V]
|
||
|
o, # output [B, H, L, D_head_V]
|
||
|
initial_state,
|
||
|
final_state, # final hidden state [B, H, D_head_K, D_head_V]
|
||
|
|
||
|
s_qk_h, # stride size: L * D_head_K
|
||
|
s_qk_t, # stride size: D_head_K
|
||
|
s_qk_d, # stride size: 1
|
||
|
|
||
|
s_vo_h, # stride size: L * D_head_V
|
||
|
s_vo_t, # stride size: D_head_V
|
||
|
s_vo_d, # stride size: 1
|
||
|
|
||
|
B, # batch size
|
||
|
H, # n_heads
|
||
|
T, # seq_len
|
||
|
scale, # D_head_K ** -0.5
|
||
|
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||
|
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||
|
DK: tl.constexpr, # D_head_K
|
||
|
DV: tl.constexpr, # D_head_V
|
||
|
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||
|
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
|
||
|
):
|
||
|
# indices
|
||
|
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||
|
i_h = i_bh % H
|
||
|
|
||
|
# decay rate given the head index
|
||
|
b_b = (1 - tl.math.pow(2, -5 - i_h * 1.0))
|
||
|
|
||
|
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||
|
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||
|
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||
|
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||
|
|
||
|
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
|
||
|
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
|
||
|
mask_kv = mask_bk[None, :] & mask_bv[:, None]
|
||
|
|
||
|
h = tl.zeros([BV, BK], dtype=tl.float32)
|
||
|
|
||
|
if USE_INITIAL_STATE:
|
||
|
p_init_s = initial_state + i_bh * DK * DV + \
|
||
|
(i_k * BK + tl.arange(0, BK)[None, :]) * \
|
||
|
DV + (i_v * BV + tl.arange(0, BV)[:, None])
|
||
|
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
|
||
|
|
||
|
for _ in range(0, T):
|
||
|
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||
|
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||
|
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||
|
|
||
|
h = b_b * h + _k[None, :] * _v[:, None]
|
||
|
_o = h * _q[None, :]
|
||
|
_o = tl.sum(_o, axis=1)
|
||
|
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
|
||
|
|
||
|
p_q += DK
|
||
|
p_k += DK
|
||
|
p_o += DV
|
||
|
p_v += DV
|
||
|
|
||
|
if STORE_FINAL_STATE:
|
||
|
p_final_s = final_state + i_bh * DK * DV + \
|
||
|
(i_k * BK + tl.arange(0, BK)[None, :]) * \
|
||
|
DV + (i_v * BV + tl.arange(0, BV)[:, None])
|
||
|
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
|
||
|
|
||
|
|
||
|
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
|
||
|
@triton.jit
|
||
|
def fused_recurrent_retention_bwd_kernel(
|
||
|
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||
|
# NV: number of split in the V dimension. NK: number of split in the K dimension
|
||
|
q, # query [B, H, L, D_head_K]
|
||
|
k, # key [B, H, L, D_head_V]
|
||
|
v, # value [B, H, L, D_head_V]
|
||
|
|
||
|
do, # gradient of output [B, H, L, D_head_V]
|
||
|
dq, # gradient of query [NV, B, H, L, D_head_K]
|
||
|
dk, # gradient of key [NV, B, H, L, D_head_K]
|
||
|
dv, # gradient of value [NK, B, H, L, D_head_V]
|
||
|
|
||
|
# initial hidden state initialization [B, H, D_head_K, D_head_V]
|
||
|
initial_state,
|
||
|
|
||
|
s_qk_h, # stride size: L * D_head_K
|
||
|
s_qk_t, # stride size: D_head_K
|
||
|
s_qk_d, # stride size: 1
|
||
|
|
||
|
s_vo_h, # stride size: L * D_head_V
|
||
|
s_vo_t, # stride size: D_head_V
|
||
|
s_vo_d, # stride size: 1
|
||
|
|
||
|
B, # batch_size
|
||
|
H, # n_heads
|
||
|
T, # seq_len
|
||
|
scale, # D_head_K ** -0.5
|
||
|
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||
|
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||
|
DK: tl.constexpr, # D_head_K
|
||
|
DV: tl.constexpr, # D_head_V
|
||
|
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||
|
):
|
||
|
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||
|
i_h = i_bh % H
|
||
|
|
||
|
b_b = 1 - tl.math.pow(2, -5 - i_h * 1.0)
|
||
|
|
||
|
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||
|
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||
|
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||
|
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||
|
|
||
|
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||
|
mask_bk = i_k * BK + tl.arange(0, BK) < DK
|
||
|
mask_bv = i_v * BV + tl.arange(0, BV) < DV
|
||
|
|
||
|
h = tl.zeros([BK, BV], dtype=tl.float32)
|
||
|
|
||
|
if USE_INITIAL_STATE:
|
||
|
mask_kv = mask_bk[:, None] & mask_bv[None, :]
|
||
|
p_init_s = initial_state + i_bh * DK * DV + \
|
||
|
(i_k * BK + tl.arange(0, BK)[:, None]) * \
|
||
|
DV + (i_v * BV + tl.arange(0, BV)[None, :])
|
||
|
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
|
||
|
|
||
|
for i in range(0, T):
|
||
|
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||
|
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||
|
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||
|
|
||
|
h = b_b * h + _k[:, None] * _v[None, :]
|
||
|
_d_q = h * _do[None, :]
|
||
|
d_q = tl.sum(_d_q, axis=1) * scale
|
||
|
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
|
||
|
|
||
|
p_k += DK
|
||
|
p_do += DV
|
||
|
p_v += DV
|
||
|
p_dq += DK
|
||
|
|
||
|
# sync threads
|
||
|
tl.debug_barrier()
|
||
|
|
||
|
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
|
||
|
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
|
||
|
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
|
||
|
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
|
||
|
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
|
||
|
BK + tl.arange(0, BK) + (T - 1) * DK
|
||
|
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
|
||
|
BV + tl.arange(0, BV) + (T - 1) * DV
|
||
|
d_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||
|
|
||
|
for _ in range(T):
|
||
|
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||
|
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||
|
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||
|
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||
|
d_h += _q[:, None] * _do[None, :]
|
||
|
d_k = tl.sum(d_h * _v[None, :], axis=1)
|
||
|
d_v = tl.sum(d_h * _k[:, None], axis=0)
|
||
|
|
||
|
d_h *= b_b
|
||
|
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
|
||
|
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
|
||
|
|
||
|
p_do -= DV
|
||
|
p_q -= DK
|
||
|
p_k -= DK
|
||
|
p_v -= DV
|
||
|
p_dk -= DK
|
||
|
p_dv -= DV
|
||
|
|
||
|
|
||
|
class FusedRecurrentRetentionFunction(torch.autograd.Function):
|
||
|
|
||
|
@staticmethod
|
||
|
@contiguous
|
||
|
def forward(ctx, q, k, v, initial_state=None, output_final_state=False):
|
||
|
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||
|
d_head_v = v.shape[-1]
|
||
|
|
||
|
scale = d_head_qk ** -0.5
|
||
|
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
|
||
|
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||
|
num_stages = 1
|
||
|
num_warps = 1
|
||
|
|
||
|
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||
|
|
||
|
if output_final_state:
|
||
|
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
|
||
|
else:
|
||
|
final_state = None
|
||
|
|
||
|
grid = (NV, NK, batch_size * n_heads)
|
||
|
fused_recurrent_retention_fwd_kernel[grid](
|
||
|
q, k, v, o, initial_state, final_state,
|
||
|
q.stride(1), q.stride(2), q.stride(3),
|
||
|
v.stride(1), v.stride(2), v.stride(3),
|
||
|
batch_size, n_heads, seq_len, scale,
|
||
|
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||
|
num_warps=num_warps,
|
||
|
num_stages=num_stages,
|
||
|
USE_INITIAL_STATE=initial_state is not None,
|
||
|
STORE_FINAL_STATE=final_state is not None
|
||
|
)
|
||
|
|
||
|
o = o.sum(0)
|
||
|
ctx.save_for_backward(q, k, v, initial_state)
|
||
|
return o, final_state
|
||
|
|
||
|
@staticmethod
|
||
|
@contiguous
|
||
|
def backward(ctx, do, d_final_state=None):
|
||
|
q, k, v, initial_state = ctx.saved_tensors
|
||
|
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||
|
d_head_v = v.shape[-1]
|
||
|
scale = d_head_qk ** -0.5
|
||
|
|
||
|
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
|
||
|
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||
|
num_stages = 1
|
||
|
num_warps = 1
|
||
|
|
||
|
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||
|
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||
|
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||
|
grid = (NV, NK, batch_size * n_heads)
|
||
|
|
||
|
fused_recurrent_retention_bwd_kernel[grid](
|
||
|
q, k, v, do, dq, dk, dv, initial_state,
|
||
|
q.stride(1), q.stride(2), q.stride(3),
|
||
|
v.stride(1), v.stride(2), v.stride(3),
|
||
|
batch_size, n_heads, seq_len, scale,
|
||
|
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||
|
num_warps=num_warps,
|
||
|
num_stages=num_stages,
|
||
|
USE_INITIAL_STATE=initial_state is not None
|
||
|
)
|
||
|
dq = dq.sum(0)
|
||
|
dk = dk.sum(0)
|
||
|
dv = dv.sum(0)
|
||
|
return dq, dk, dv, None, None
|
||
|
|
||
|
|
||
|
# fused_recurrent_retention = FusedRecurrentRetentionFunction.apply
|
||
|
|
||
|
def fused_recurrent_retention(
|
||
|
q: torch.Tensor,
|
||
|
k: torch.Tensor,
|
||
|
v: torch.Tensor,
|
||
|
initial_state: torch.Tensor = None,
|
||
|
output_final_state: bool = False
|
||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||
|
if initial_state is not None:
|
||
|
initial_state = initial_state.detach()
|
||
|
o, final_state = FusedRecurrentRetentionFunction.apply(q, k, v, initial_state, output_final_state)
|
||
|
return o, final_state
|