This commit is contained in:
josc146
2024-05-28 22:35:47 +08:00
parent 3488d22d22
commit f05a4acb04
138 changed files with 29047 additions and 334 deletions

18
finetune/lora/v6/fla/ops/__init__.py vendored Normal file
View File

@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
from .based import fused_chunk_based, parallel_based
from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
from .retention import (chunk_retention, fused_chunk_retention,
fused_recurrent_retention, parallel_retention)
__all__ = [
'fused_chunk_based',
'parallel_based',
'chunk_gla',
'fused_chunk_gla',
'fused_recurrent_gla',
'chunk_retention',
'fused_chunk_retention',
'fused_recurrent_retention',
'parallel_retention'
]

View File

@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_abc
from .chunk_gate import chunk_gated_abc
from .recurrent_fuse import fused_recurrent_gated_abc
__all__ = [
'chunk_abc',
'chunk_gated_abc',
'fused_recurrent_gated_abc'
]

1194
finetune/lora/v6/fla/ops/abc/chunk.py vendored Normal file

File diff suppressed because it is too large Load Diff

1287
finetune/lora/v6/fla/ops/abc/chunk_gate.py vendored Normal file

File diff suppressed because it is too large Load Diff

90
finetune/lora/v6/fla/ops/abc/naive.py vendored Normal file
View File

@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
from typing import Optional
import torch
def naive_recurrent_abc(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
s: torch.Tensor,
g: Optional[torch.Tensor] = None,
scale: Optional[int] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: Optional[bool] = False
) -> torch.Tensor:
dtype = q.dtype
# [batch_size, n_heads, seq_len, n_slots]
if g is None:
z = s.float().logcumsumexp(2)
g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
s = torch.exp(s - z)
q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
ok = torch.zeros_like(s)
if scale is None:
scale = q.shape[-1] ** -0.5
final_state = None
if initial_state is not None:
hk += initial_state[0]
for i in range(T):
q_i = q[:, :, i] * scale
k_i = k[:, :, i]
v_i = s[:, :, i]
g_i = g[:, :, i].exp()
hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
ok[:, :, i] = (q_i[..., None] * hk).sum(-2)
qv = ok.softmax(-1)
hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
ov = torch.zeros_like(v)
if initial_state is not None:
hv += initial_state[1]
for i in range(T):
q_i = qv[:, :, i]
k_i = s[:, :, i]
v_i = v[:, :, i]
g_i = g[:, :, i].exp()
hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
ov[:, :, i] = (q_i[..., None] * hv).sum(-2)
if output_final_state:
final_state = (hk, hv)
return ov.to(dtype), final_state
def naive_cumsum_abc(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
s: torch.Tensor
) -> torch.Tensor:
"""
A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper.
This is just for demonstration purposes, with no numerical stabilities guaranteed.
"""
dtype = q.dtype
q, k, v, s = map(lambda x: x.float(), (q, k, v, s))
scale = q.shape[-1] ** -0.5
# [batch_size, n_heads, seq_len, n_slots]
s = (s - s.max(2, True)[0]).exp()
z = s.cumsum(2)
# [batch_size, n_heads, seq_len, n_slots, d_head]
K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
# [batch_size, n_heads, seq_len, n_slots]
p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1)
# [batch_size, n_heads, seq_len, d_head]
o = torch.einsum('...m,...md->...d', p, V)
return o.to(dtype), None

View File

@@ -0,0 +1,388 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Yu Zhang, Songlin Yang
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
@triton.jit
def fused_recurrent_gated_abc_fwd_kernel(
q,
k,
v,
gk,
gv,
o,
h0,
ht,
s_k_h,
s_v_h,
scale,
B: tl.constexpr,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
REVERSE: tl.constexpr,
USE_GK: tl.constexpr,
USE_GV: tl.constexpr,
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
mask_bk = (i_k * BK + tl.arange(0, BK)) < K
mask_bv = (i_v * BV + tl.arange(0, BV)) < V
h = tl.zeros([BV, BK], dtype=tl.float32)
mask_kv = mask_bk[None, :] & mask_bv[:, None]
if USE_INITIAL_STATE:
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
if USE_GK:
b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
h = h * b_gk[None, :]
if USE_GV:
b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
h = h * b_gv[:, None]
h += b_k[None, :] * b_v[:, None]
b_o = h * b_q[None, :]
b_o = tl.sum(b_o, axis=1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
p_q += -K if REVERSE else K
p_k += -K if REVERSE else K
p_o += -V if REVERSE else V
p_v += -V if REVERSE else V
if USE_GK:
p_gk += -K if REVERSE else K
if USE_GV:
p_gv += -V if REVERSE else V
if STORE_FINAL_STATE:
p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)
@triton.jit
def fused_recurrent_gated_abc_bwd_kernel(
q,
k,
v,
gk,
gv,
do,
dq,
dk,
dv,
h0,
s_k_h,
s_v_h,
scale,
B: tl.constexpr,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
REVERSE: tl.constexpr,
USE_GK: tl.constexpr,
USE_GV: tl.constexpr,
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
mask_bk = i_k * BK + tl.arange(0, BK) < K
mask_bv = i_v * BV + tl.arange(0, BV) < V
mask_kv = mask_bk[:, None] & mask_bv[None, :]
h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
if USE_GK:
b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
h = h * b_gk[:, None]
if USE_GV:
b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
h = h * b_gv[None, :]
h += b_k[:, None] * b_v[None, :]
b_dq = tl.sum(h * b_do[None, :], axis=1) * scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)
p_k += -K if REVERSE else K
p_v += -V if REVERSE else V
p_q += -K if REVERSE else K
p_do += -V if REVERSE else V
p_dq += -K if REVERSE else K
if USE_GK:
p_gk += -K if REVERSE else K
if USE_GV:
p_gv += -V if REVERSE else V
# sync threads
tl.debug_barrier()
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for _ in range(T):
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
b_dh += b_q[:, None] * b_do[None, :]
b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
if USE_GK:
b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
b_dh *= b_gk[:, None]
if USE_GV:
b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
b_dh *= b_gv[None, :]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
p_q += K if REVERSE else -K
p_k += K if REVERSE else -K
p_v += V if REVERSE else -V
p_do += V if REVERSE else -V
p_dk += K if REVERSE else -K
p_dv += V if REVERSE else -V
if USE_GK:
p_gk += K if REVERSE else -K
if USE_GV:
p_gv += V if REVERSE else -V
class FusedRecurrentGatedABCFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, s, g, scale=None, initial_state=None, output_final_state=False, reverse=False):
B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
# default scale
if scale is None:
scale = K ** -0.5
BK, BV, BM = min(K, 32), min(V, 32), min(M, 32)
NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
num_stages = 1
num_warps = 1
g = g.float().exp()
final_state = (None, None)
if output_final_state:
final_state = (q.new_empty(B, H, K, M), q.new_empty(B, H, M, V))
ok = q.new_empty(NK, B, H, T, M, dtype=torch.float)
gk, gv = None, g
grid = (NM, NK, B * H)
fused_recurrent_gated_abc_fwd_kernel[grid](
q, k, s, gk, gv, ok, initial_state[0], final_state[0],
k.stride(1),
s.stride(1),
scale=scale,
B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,
USE_INITIAL_STATE=initial_state[0] is not None,
STORE_FINAL_STATE=final_state[0] is not None,
USE_GK=False,
USE_GV=True,
REVERSE=reverse,
num_warps=num_warps,
num_stages=num_stages
)
ok = ok.sum(0)
qv = ok.softmax(-1, dtype=torch.float)
ov = q.new_empty(NM, B, H, T, V, dtype=torch.float)
gk, gv = g, None
grid = (NV, NM, B * H)
fused_recurrent_gated_abc_fwd_kernel[grid](
qv, s, v, gk, gv, ov, initial_state[1], final_state[1],
s.stride(1),
v.stride(1),
scale=1.,
B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,
USE_INITIAL_STATE=initial_state[0] is not None,
STORE_FINAL_STATE=final_state[0] is not None,
USE_GK=True,
USE_GV=False,
REVERSE=reverse,
num_warps=num_warps,
num_stages=num_stages
)
ov = ov.sum(0)
ctx.save_for_backward(q, k, v, s, g, qv, *initial_state, ok)
ctx.scale = scale
ctx.reverse = reverse
# we do not need the gradient of the final state from the next chunk
# similiar to Trunctated BPTT
if final_state is not None:
final_state = tuple(i.detach() for i in final_state)
return ov.to(q.dtype), final_state
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, dht=None):
q, k, v, s, g, qv, *initial_state, ok = ctx.saved_tensors
B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
V = v.shape[-1]
scale = ctx.scale
BK, BV, BM = min(K, 32), min(V, 32), min(M, 32)
NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
num_stages = 1
num_warps = 1
dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
dv = q.new_empty(NM, B, H, T, V, dtype=torch.float)
gk, gv = g, None
grid = (NV, NM, B * H)
fused_recurrent_gated_abc_bwd_kernel[grid](
qv, s, v, gk, gv, do, dqv, dsv, dv, initial_state[1],
s.stride(1),
v.stride(1),
scale=1.,
B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state[1] is not None,
REVERSE=ctx.reverse,
USE_GK=gk is not None,
USE_GV=gv is not None
)
dqv = dqv.sum(0)
dsv = dsv.sum(0)
dv = dv.sum(0)
dgk = dqv * qv.float() - dsv * s.float()
dgk_cumsum = dgk.cumsum(-2)
dgk = dgk + dgk_cumsum[:, :, -1, None] - dgk_cumsum
dok = qv * (dqv - (qv * dqv).sum(-1, True))
dq = q.new_empty(NM, B, H, T, K, dtype=torch.float)
dk = q.new_empty(NM, B, H, T, K, dtype=torch.float)
dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float)
gk, gv = None, g
grid = (NM, NK, B * H)
fused_recurrent_gated_abc_bwd_kernel[grid](
q, k, s, gk, gv, dok, dq, dk, dsk, initial_state[0],
q.stride(1),
s.stride(1),
scale=scale,
B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state[0] is not None,
REVERSE=ctx.reverse,
USE_GK=gk is not None,
USE_GV=gv is not None
)
dq = dq.sum(0)
dk = dk.sum(0)
dsk = dsk.sum(0)
dgv = dok.float() * ok.float() - dsk * s.float()
dgv_cumsum = dgv.cumsum(-2)
dgv = dgv + dgv_cumsum[:, :, -1, None] - dgv_cumsum
ds = dsk.add_(dsv)
dg = dgk.add_(dgv)
return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, None, None, None
def fused_recurrent_gated_abc(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
s: torch.Tensor,
g: Optional[torch.Tensor] = None,
scale: Optional[int] = None,
initial_state: Optional[Tuple[torch.Tensor]] = None,
output_final_state: Optional[bool] = False
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
q (torch.Tensor):
queries of shape `(B, H, T, K)`
k (torch.Tensor):
keys of shape `(B, H, T, K)`
v (torch.Tensor):
values of shape `(B, H, T, V)`
g (torch.Tensor):
Forget gates of shape `(B, H, T, M)` applied to keys.
If not provided, this function is equivalent to vanilla ABC.
scale (Optional[int]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[Tuple[torch.Tensor]]):
Initial state tuple having tensors of shape `(B, H, K, V)`. Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state tuple, having tensors of shape `(B, H, K, V)`. Default: `False`.
"""
if initial_state is not None:
initial_state = tuple(i.detach() for i in initial_state)
if g is None:
# TODO: this 3 steps took huge amount of time, ought to be optimized
z = s.float().logcumsumexp(2)
g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
s = torch.exp(s - z).to(k.dtype)
if scale is None:
scale = q.shape[-1] ** -0.5
ov, final_state = FusedRecurrentGatedABCFunction.apply(q, k, v, s, g, scale, initial_state, output_final_state)
return ov, final_state

View File

@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from .chunk_fuse import fused_chunk_based
from .parallel import parallel_based
__all__ = [
'fused_chunk_based',
'parallel_based'
]

View File

@@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@triton.jit
def fused_chunk_based_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
z, # normalizer [B, H, L, 1]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
# [BT, BT]
m_s = o_i[:, None] >= o_i[None, :]
# [BV], zero-order taylor expansion
b_h_0o = tl.zeros([BV], dtype=tl.float32)
# [BK, BV], first-order taylor expansion
b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)
# [BK, BK, BV] second-order taylor expansion
b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
# make block pointers
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
(s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)
k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
k_1o = tl.zeros([1, BK], dtype=tl.float32)
k_0o = 0
for i in range(0, tl.cdiv(T, BT)):
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BK*BK, BT]
b_k_2o = b_k[:, None, :] * b_k[None, :, :]
b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BK]
b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_z = tl.zeros([BT], dtype=tl.float32)
# interchunk
# zero-order
b_o += b_h_0o
b_z += k_0o
# first-order
b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)
b_z += tl.sum(b_q * k_1o, axis=1)
# second-order
b_q_2o = b_q[:, :, None] * b_q[:, None, :]
b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)
b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5
b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5
# update running statistics
k_1o += tl.sum(b_k, axis=1)[None, :]
k_2o += tl.sum(b_k_2o, axis=1)[None, :]
k_0o += BT
# intrachunk
# [BT, BT]
b_s = tl.dot(b_q, b_k, allow_tf32=False)
b_s = 1 + b_s + 0.5 * b_s * b_s
b_s = tl.where(m_s, b_s, 0)
b_z += tl.sum(b_s, axis=1)
b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
# [TB, BV]
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_z, b_z.to(p_z.dtype.element_ty),
mask=(i * BT + tl.arange(0, BT)) < T)
# update hidden state
# [BK, BV]
b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)
b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)
b_h_0o = b_h_0o + tl.sum(b_v, axis=0)
p_q = tl.advance(p_q, (BT, 0))
p_k = tl.advance(p_k, (0, BT))
p_v = tl.advance(p_v, (BT, 0))
p_o = tl.advance(p_o, (BT, 0))
p_z += BT
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_chunk_based_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
do, # gradient of output [B, H, L, D_head_V]
dz, # gradient of normalizer [B, H, L]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
# [BV], zero-order taylor expansion
# b_h_0o = tl.zeros([BV], dtype=tl.float32)
# [BK, BV], first-order taylor expansion
b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)
# [BK, BK, BV] second-order taylor expansion
b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)
k_1o = tl.zeros([1, BK], dtype=tl.float32)
k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
for i in range(0, tl.cdiv(T, BT)):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(
k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(
v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h,
(T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
# load tensors
# [BT, BK]
b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BV, BT]
b_v = tl.load(p_v, boundary_check=(0, 1))
# inter-chunk
b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)
if i_v == 0:
b_dq += b_dz[:, None] * k_1o
b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5
if i_v == 0:
b_dq_2o += (b_dz[:, None] * k_2o) * 0.5
b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])
b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)
b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)
b_dq *= scale
# intra-chunk
# [BT, BT]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
if i_v == 0:
b_ds += b_dz[:, None]
b_ds = tl.where(m_s, b_ds, 0) * scale
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)
# store
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
# update hidden state
# [BT, BK*BK]
b_k_2o = b_k[:, :, None] * b_k[:, None, :]
b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
# [BV, BK*BK]
b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)
# [BV, BK]
b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)
if i_v == 0:
# update running statistics
k_1o += tl.sum(b_k, axis=0)[None, :]
k_2o += tl.sum(b_k_2o, axis=0)[None, :]
tl.debug_barrier()
b_h_1o = None
b_h_2o = None
# [BK, BV], first-order taylor expansion
b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)
# [BK, BK, BV] second-order taylor expansion
b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
b_dh_0o = tl.zeros([BV], dtype=tl.float32)
m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]
dq_1o = tl.zeros([1, BK], dtype=tl.float32)
dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)
for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(
k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(
v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0))
p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dv = tl.zeros([BT, BV], dtype=tl.float32)
b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
b_q = (b_q * scale).to(b_k.dtype)
# intra chunk
b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
if i_v == 0:
b_ds += b_dz[None, :]
b_ds = tl.where(m_s, b_ds, 0)
b_s = tl.dot(b_k, b_q, allow_tf32=False)
b_s2 = 1 + b_s + 0.5 * b_s * b_s
b_s = tl.where(m_s, b_s, 0)
b_s2 = tl.where(m_s, b_s2, 0)
b_ds *= (1+b_s)
b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)
b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)
# inter chunk
b_k_2o = b_k[:, :, None] * b_k[:, None, :]
b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)
b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)
b_dv += b_dh_0o
b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)
if i_v == 0:
b_dk += dq_1o
b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype),
tl.trans(b_v), allow_tf32=False)
if i_v == 0:
b_dk_2o += dq_2o
b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])
b_k_fp32 = tl.trans(b_k.to(tl.float32))
b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)
b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)
b_dk += tl.trans(b_dk2)
# hidden state update
b_dh_0o += tl.sum(b_do, axis=0)
b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)
b_q_2o = b_q[None, :, :] * b_q[:, None, :]
b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)
b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5
if i_v == 0:
dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]
dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
class FusedChunkBasedFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, scale=1):
batch_size, n_heads, seq_len, d_head_qk = q.shape
# assert d_head_qk == 16, "currently we do not support feature dim other than 16"
d_head_v = v.shape[-1]
scale = scale
BT = 16
BK, BV = min(d_head_qk, 16), min(d_head_v, 32)
BK, BV = max(BK, 16), max(BV, 16)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_warps = 4
# the norm of o might explode, so we need to use float32 here
o = q.new_empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=torch.float32)
z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_based_fwd_kernel[grid](
q, k, v, o, z,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
)
o = o.sum(0)
z = z.sum(0)
ctx.save_for_backward(q, k, v)
ctx.scale = scale
return o.to(q.dtype), z.to(z.dtype)
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, dz):
q, k, v = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = ctx.scale
BT = 16
BK, BV = min(d_head_qk, 16), min(d_head_v, 32)
BK, BV = max(BK, 16), max(BV, 16)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 4
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_based_bwd_kernel[grid](
q, k, v, do, dz, dq, dk, dv,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None
triton_fused_chunk_based = FusedChunkBasedFunction.apply
def fused_chunk_based(q, k, v, use_scale=True, use_normalize=True):
assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'
if use_scale:
scale = q.shape[-1] ** -0.5
else:
scale = 1
o, z = triton_fused_chunk_based(q, k, v, scale)
if use_normalize:
o = o / (z[..., None] + 1e-6)
else:
o = o
return o.to(q.dtype)

132
finetune/lora/v6/fla/ops/based/naive.py vendored Normal file
View File

@@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
import torch
from einops import rearrange
from fla.ops.based.chunk_fuse import fused_chunk_based
from fla.ops.based.parallel import parallel_based
def naive_parallel_based(q, k, v, use_scale=True, use_norm=True):
if use_scale:
q = q * (q.shape[-1] ** -0.5)
attn = q @ k.transpose(-2, -1)
attn = 1 + attn + 1/2 * (attn ** 2)
attn.masked_fill_(~torch.tril(torch.ones(
q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
o = attn @ v
if use_norm:
z = attn.sum(-1)
return o / (z[..., None] + 1e-6)
else:
return o
def naive_chunk_based(q, k, v, chunk_size=256):
q = q * (q.shape[-1] ** -0.5)
# compute normalizer.
k_cumsum = torch.cumsum(k, dim=-2)
kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
# first
z = (q * k_cumsum).sum(-1)
# second order
z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
# zero-th order
z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]
# compute o
# constant term
_o = v.cumsum(-2)
q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)
k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
intra_chunk_attn = q @ k.transpose(-2, -1)
intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
intra_chunk_attn.masked_fill_(
~torch.tril(
torch.ones(chunk_size, chunk_size,
dtype=torch.bool, device=q.device),
), 0)
o = intra_chunk_attn @ v
# quadractic term
kv = torch.einsum(
'b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
kv = kv.cumsum(2)
kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)
# linear term
kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
kv = kv.cumsum(2)
kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)
o = rearrange(o, 'b h n c d -> b h (n c) d')
o = o + _o
return o / (z[..., None] + 1e-6)
if __name__ == "__main__":
B = 4
H = 4
L = 128
# D = 15
dtype = torch.float32
q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True)
do = torch.randn_like(v).cuda()
ref = naive_parallel_based(q, k, v, True, True)
ref.backward(do, retain_graph=True)
ref_dq, q.grad = q.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dv, v.grad = v.grad.clone(), None
# tri = naive_chunk_based(q, k, v)
# tri.backward(do, retain_graph=True)
# tri_dq, q.grad = q.grad.clone(), None
# tri_dk, k.grad = k.grad.clone(), None
# tri_dv, v.grad = v.grad.clone(), None
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
tri = fused_chunk_based(q, k, v, True, True)
tri.backward(do, retain_graph=True)
tri_dq, q.grad = q.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dv, v.grad = v.grad.clone(), None
print((ref-tri).abs().max())
print((ref_dq-tri_dq).abs().max())
print((ref_dk-tri_dk).abs().max())
print((ref_dv-tri_dv).abs().max())
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
tri = parallel_based(q, k, v, True, True)
tri.backward(do, retain_graph=True)
tri_dq, q.grad = q.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dv, v.grad = v.grad.clone(), None
print((ref-tri).abs().max())
print((ref_dq-tri_dq).abs().max())
print((ref_dk-tri_dk).abs().max())
print((ref_dv-tri_dv).abs().max())
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()

View File

@@ -0,0 +1,388 @@
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# Based: An Educational and Effective Sequence Mixer
# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
@triton.jit
def parallel_based_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
z, # normalizer [B, H, L]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q
BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
):
# i_c: chunk index. used for sequence parallelism
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
NV = tl.cdiv(DV, BV)
i_k = i_kv // (NV)
i_v = i_kv % (NV)
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
(s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))
# [BQ, BD] block Q, in the shared memory throughout the whole kernel
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
b_o = tl.zeros([BTL, BV], dtype=tl.float32)
b_z = tl.zeros([BTL], dtype=tl.float32)
# Q block and K block have no overlap
# no need for mask, thereby saving flops
for _ in range(0, i_c * BTL, BTS):
# [BK, BTS]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BTS, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
b_s = tl.dot(b_q, (b_k), allow_tf32=False)
b_s = 1 + b_s + 0.5 * b_s * b_s
b_z += tl.sum(b_s, axis=1)
# [BQ, BD]
b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
p_k = tl.advance(p_k, (0, BTS))
p_v = tl.advance(p_v, (BTS, 0))
# # rescale interchunk output
tl.debug_barrier()
o_q = tl.arange(0, BTL)
# # sync threads, easy for compiler to optimize
# tl.debug_barrier()
o_k = tl.arange(0, BTS)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
(s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
# Q block and K block have overlap. masks required
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
# [BK, BTS]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BTS, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
m_s = o_q[:, None] >= o_k[None, :]
b_s = tl.dot(b_q, b_k, allow_tf32=False)
b_s = 1 + b_s + 0.5 * b_s * b_s
b_s = tl.where(m_s, b_s, 0)
b_z += tl.sum(b_s, axis=1)
# [BTL, BV]
b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
p_k = tl.advance(p_k, (0, BTS))
p_v = tl.advance(p_v, (BTS, 0))
o_k += BTS
p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_z, b_z.to(p_z.dtype.element_ty),
mask=((i_c * BTL + tl.arange(0, BTL)) < T))
@triton.jit
def _parallel_based_bwd_dq(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
b_q = (b_q * scale).to(b_q.dtype)
b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
(s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))
p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)
b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)
for _ in range(0, i_c * BTL, BTS):
# [BTS, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BV, BTS]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
if i_v == 0:
b_ds += b_dz[:, None]
else:
b_ds = b_ds
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
# [BQ, BD]
b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False)
p_k = tl.advance(p_k, (BTS, 0))
p_v = tl.advance(p_v, (0, BTS))
b_dq *= scale
o_q = tl.arange(0, BTL)
o_k = tl.arange(0, BTS)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
(s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
# Q block and K block have overlap. masks required
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
# [BTS, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BV, BTS]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
m_s = o_q[:, None] >= o_k[None, :]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
if i_v == 0:
b_ds += b_dz[:, None]
else:
b_ds = b_ds
b_ds = tl.where(m_s, b_ds, 0) * scale
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
# [BTL, BK]
b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype),
b_k, allow_tf32=False)
p_k = tl.advance(p_k, (BTS, 0))
p_v = tl.advance(p_v, (0, BTS))
o_k += BTS
p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
return
@triton.jit
def _parallel_based_bwd_dkv(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
# compute dk dv
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
(i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(
p_v, boundary_check=(0, 1))
b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
[BTL, BV], dtype=tl.float32)
for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS]
b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \
scale # [BTL, BTS]
b_s2 = 1 + b_s + 0.5 * b_s * b_s
b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
if i_v == 0:
b_ds += b_dz[None, :] * scale
else:
b_ds = b_ds
b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype),
tl.trans(b_q), allow_tf32=False)
tl.debug_barrier()
o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
# [BK, BQ]
m_s = o_k[:, None] <= o_q[None, :]
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
b_s2 = 1 + b_s + 0.5 * b_s * b_s
b_s = tl.where(m_s, b_s, 0)
b_s2 = tl.where(m_s, b_s2, 0)
b_ds = tl.dot(b_v, b_do, allow_tf32=False)
if i_v == 0:
b_ds += b_dz[None, :]
else:
b_ds = b_ds
b_ds = tl.where(m_s, b_ds, 0) * scale
# [BK, BD]
b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype),
tl.trans(b_q), allow_tf32=False)
o_q += BTS
p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,
(T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,
(T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
return
@triton.jit
def parallel_based_bwd_kernel(
q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
NV = tl.cdiv(DV, BV)
i_k = i_kv // (NV)
i_v = i_kv % (NV)
i_h = i_bh % H
_parallel_based_bwd_dq(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV
)
tl.debug_barrier()
_parallel_based_bwd_dkv(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV
)
class ParallelBasedFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, scale):
BTL, BTS = 128, 32
assert BTL % BTS == 0
# assert q.shape[-1] % 16 == 0
BK = min(128, triton.next_power_of_2(k.shape[-1]))
BV = min(128, triton.next_power_of_2(v.shape[-1]))
BK, BV = max(BK, 16), max(BV, 16)
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
num_stages = 2
num_warps = 4
NK = triton.cdiv(d_head_qk, BK)
NV = triton.cdiv(d_head_v, BV)
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
assert NK == 1, "will encounter some synchronization issue if not."
o = torch.empty(NK, batch_size, n_heads, seq_len,
d_head_v, device=q.device)
z = torch.empty(NK, batch_size, n_heads, seq_len,
device=q.device)
parallel_based_fwd_kernel[grid](
q, k, v, o, z,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
num_warps=num_warps,
num_stages=num_stages
)
ctx.save_for_backward(q, k, v)
ctx.scale = scale
return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, dz):
q, k, v = ctx.saved_tensors
scale = ctx.scale
BTL, BTS = 64, 32
assert BTL % BTS == 0
BK = min(128, triton.next_power_of_2(k.shape[-1]))
BV = min(128, triton.next_power_of_2(v.shape[-1]))
BK, BV = max(BK, 16), max(BV, 16)
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
num_stages = 2
num_warps = 4
NK = triton.cdiv(d_head_qk, BK)
NV = triton.cdiv(d_head_v, BV)
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
assert NK == 1, "will encounter some synchronization issue if not"
dq = torch.empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=q.dtype, device=q.device)
dk = torch.empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=q.dtype, device=q.device)
dv = torch.empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=q.dtype, device=q.device)
parallel_based_bwd_kernel[grid](
q, k, v, do, dz, dq, dk, dv,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
num_warps=num_warps,
num_stages=num_stages
)
return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
triton_parallel_based = ParallelBasedFunction.apply
def parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=False):
assert q.shape[-1] <= 128, "only support feature dim up to 128"
if use_scale:
scale = q.shape[-1] ** -0.5
else:
scale = 1
o, z = triton_parallel_based(q, k, v, scale)
if return_both:
return o, z
if use_normalize:
o = o / (z[..., None] + 1e-6)
else:
o = o
return o.to(q.dtype)

View File

@@ -0,0 +1,4 @@
- Delta Rule
The implementation of delta rule described in https://arxiv.org/abs/2102.11174

View File

@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
from .chunk_fuse import fused_chunk_delta_rule
from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule
from .chunk import chunk_delta_rule
__all__ = [
'fused_chunk_delta_rule',
'fused_recurrent_linear_attn_delta_rule',
'chunk_delta_rule'
]

View File

@@ -0,0 +1,544 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
import torch
import triton
import triton.language as tl
from fla.ops.utils import contiguous
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.ops.delta_rule.wy_fast import fwd_recompute_w_u, fwd_prepare_wy_repr, bwd_prepare_wy_repr
from fla.ops.delta_rule.chunk_fuse import fused_chunk_delta_rule_fwd, fused_chunk_delta_rule_bwd
# from fla.ops.delta_rule.utils import bwd_prepare_wy_repr
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_dv_kernel(
q,
k,
do,
dv,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
scale,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_k.dtype)
b_A += tl.dot(b_k, b_q, allow_tf32=False)
b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A , 0).to(do.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_dv = tl.dot(b_A, b_do, allow_tf32=False)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
def fwd_prepare_dv(q, k, do, BT):
dv = torch.empty_like(do)
B, H, T, K, V = *k.shape, do.shape[-1]
NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
fwd_prepare_dv_kernel[(NT, B*H)](
q, k, do, dv,
k.stride(1), k.stride(2), k.stride(3),
do.stride(1), do.stride(2), do.stride(3),
T, K, V, K**-0.5, BT, BK, BV
)
return dv
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_fwd_kernel_h(
k,
v,
d,
v_new,
h,
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
for i_t in range(NT):
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32)
# since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
for i_c in range(tl.cdiv(BT, BC)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BK]
b_d = tl.load(p_d, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)
# [BK, BV]
tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
b_h += b_h_cumsum
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_linear_attn_fwd_kernel_o(
q,
k,
v,
h,
o,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_s = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
b_o += tl.dot(b_q, b_h, allow_tf32=False)
b_s += tl.dot(b_q, b_k, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False))
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_bwd_kernel_dhu(
q,
k,
d,
do,
dh,
dv,
dv2,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_d = tl.load(p_d, boundary_check=(0, 1))
# [BT, V]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT, BT]
# b_s = tl.dot(b_k, b_q, allow_tf32=False)
# b_s = tl.where(m_s, b_s, 0)
# b_dv = tl.dot(b_s.to(b_do.dtype), b_do, allow_tf32=False) + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
p_dv2 = tl.make_block_ptr(dv2 + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# [BK, BV]
b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
b_dh += b_dh_tmp
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_bwd_kernel_dqkw(
q,
k,
v,
w,
h,
do,
dh,
dq,
dk,
dv,
dw,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
n_bh = tl.num_programs(2)
o_i = tl.arange(0, BT)
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dw = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BT]
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dw += tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
# [BT, BT]
b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
# [BT, BK]
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
B, H, T, K, V = *k.shape, u.shape[-1]
BK = triton.next_power_of_2(K)
assert BK <= 256, "current kernel does not support head dimension larger than 256."
BV = 16 if BK > 128 else 32
BV = 64 if BK <= 64 else BV
BC = 16 if BK > 128 else 32
BC = 64 if BK <= 64 else BC
BC = min(BT, BC)
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
h = k.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
v_new = torch.empty_like(u)
chunk_delta_rule_fwd_kernel_h[grid](
k, u, w, v_new, h, initial_state, final_state,
k.stride(1), k.stride(2), k.stride(3),
u.stride(1), u.stride(2), u.stride(3),
h.stride(1), h.stride(2),
H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None,
)
return h, v_new
def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
B, H, T, K, V = *q.shape, do.shape[-1]
BK = triton.next_power_of_2(K)
assert BK <= 256, "current kernel does not support head dimension being larger than 256."
BV = 16 if BK > 128 else 32
BV = 64 if BK <= 64 else BV
BC = 16 if BK > 128 else 32
BC = 64 if BK <= 64 else BC
BC = min(BT, BC)
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
dh = q.new_empty(B, H, NT * K, V)
# dv_new = torch.empty_like(do)
grid = (NK, NV, B * H)
dv2 = torch.empty_like(dv)
chunk_delta_rule_bwd_kernel_dhu[grid](
q, k, w, do, dh, dv, dv2,
q.stride(1), q.stride(2), q.stride(3),
do.stride(1), do.stride(2), do.stride(3),
dh.stride(1), dh.stride(2),
K**-0.5,
H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
)
return dh, dv2
def chunk_fwd_o_fn(q, k, v_new, h, BT):
B, H, T, K, V = *q.shape, v_new.shape[-1]
BK = triton.next_power_of_2(K)
o = torch.empty_like(v_new)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(K), 64)
NV = triton.cdiv(V, BV)
NT = triton.cdiv(T, BT)
grid = (NV, NT, B * H)
chunk_linear_attn_fwd_kernel_o[grid](
q, k, v_new, h, o,
q.stride(1), q.stride(2), q.stride(3),
v_new.stride(1), v_new.stride(2), v_new.stride(3),
h.stride(1), h.stride(2),
scale=K**-0.5,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
)
return o
def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
B, H, T, K, V = *q.shape, v_new.shape[-1]
BK = triton.next_power_of_2(K)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
NV = triton.cdiv(V, BV)
NT = triton.cdiv(T, BT)
grid = (NV, NT, B * H)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dw = torch.empty_like(w)
chunk_delta_rule_bwd_kernel_dqkw[grid](
q, k, v_new, w, h, do, dh, dq, dk, du, dw,
q.stride(1), q.stride(2), q.stride(3),
v_new.stride(1), v_new.stride(2), v_new.stride(3),
dh.stride(1), dh.stride(2),
scale = K ** -0.5,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
)
return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
class ChunkDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
@contiguous
def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):
### obtain WY representation. u is actually the new v.
w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
# ### forward_h
final_state = None
if output_final_state:
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
## obtain output
o = chunk_fwd_o_fn(q, k, v_new, h, BT)
# save memory
if checkpoint_level == 1:
h, v_new = None, None
ctx.save_for_backward(q, k, v, beta, A, h, v_new, initial_state)
ctx.BT = BT
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_ht=None):
q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
scale = q.shape[-1] ** -0.5
BT = ctx.BT
w, u = fwd_recompute_w_u(k, v, beta, A, BT)
# checkpont_level=1, recomputation.
if h is None:
h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
dv = fwd_prepare_dv(q, k, do, BT)
dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)
dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, dv, BT)
dk.add_(dk2)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
def chunk_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
BT: int,
initial_state: torch.Tensor = None,
output_final_state: bool = False
):
assert q.dtype == k.dtype == v.dtype
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
return o, final_state

View File

@@ -0,0 +1,419 @@
# -*- coding: utf-8 -*-
from typing import Tuple
import torch
import triton
import triton.language as tl
from packaging import version
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8)
],
key=["BT", "BK"],
)
@triton.jit
def fused_chunk_delta_rule_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_K]
v, # value [B, H, L, D_head_V]
v_new,
d, # decay [B, H, L, D_head_K]
o, # output [B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
# [BT, BT]
m_s = o_i[:, None] >= o_i[None, :]
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
# make block pointers
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
for i in range(0, tl.cdiv(T, BT)):
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_d = tl.load(p_d, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_k.dtype)
# [BT, BT]
b_s = tl.dot(b_q, b_k, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
# [BT, BV]
b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)
b_v = b_v - b_v_prime
tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False)
if CHECK and i == 0:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
else:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
p_q = tl.advance(p_q, (BT, 0))
p_k = tl.advance(p_k, (0, BT))
p_v = tl.advance(p_v, (BT, 0))
p_v_new = tl.advance(p_v_new, (BT, 0))
p_o = tl.advance(p_o, (BT, 0))
p_d = tl.advance(p_d, (BT, 0))
if STORE_FINAL_STATE:
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fused_chunk_delta_rule_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
d, # decay [B, H, L, D_head_K]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
dd, # gradient of decay [NV, B, H, L, D_head_K]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
# first reverse
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
m_s = o_i[:, None] <= o_i[None, :]
for i in range(1, tl.cdiv(T, BT) + 1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
# [DK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, DV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT, BT]
b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
# [BT, BT]
b_s = tl.dot(b_k, b_q, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
# [BT, DK]
b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
# [BT, DV]
b_dv = tl.dot(b_s, b_do, allow_tf32=False)
b_d = tl.load(p_d, boundary_check=(0, 1))
if CHECK and i == 1:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
else:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# sync threads
b_h = None
tl.debug_barrier()
m_s = o_i[:, None] >= o_i[None, :]
# [BV, BK]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
NT = tl.cdiv(T, BT)
for i in range(0, NT):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [DV, BT]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, DV]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT, BT]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
b_ds = tl.where(m_s, b_ds, 0)
# [BT, DK]
b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
# [DV, DK]
if CHECK and i == 0:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
else:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
b_dq *= scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
if i < (NT - 1):
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0))
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
((i+1) * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1))
def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BT = BT
# ctx.BT = BT
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
assert NK == 1, 'NK should be 1'
o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
else:
final_state = None
CHECK = True
# if version.parse(triton.__version__) < version.parse('2.2.0'):
# import warnings
# warnings.warn(
# "Triton<2.2.0 detected for running this kernel, "
# "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
# "that lead to significant precision loss. "
# "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
# "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
# )
# CHECK = True
grid = (NV, NK, batch_size * n_heads)
v_new = torch.empty_like(v)
fused_chunk_delta_rule_fwd_kernel[grid](
q, k, v, v_new, d, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
CHECK=CHECK,
)
return o, v_new, CHECK, final_state
def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
assert NK == 1
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_delta_rule_bwd_kernel[grid](
q, k, v, d, do, dq, dk, dv, dd, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
CHECK=CHECK,
# num_warps=num_warps,
# num_stages=num_stages
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
dd = dd.sum(0)
dd[:, :, 0:BT] = 0
return dq, dk, dv, dd
class FusedChunkDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0):
# lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory.
assert checkpoint_level in [0, 1]
k_origin = k
# k = _l2_norm_fwd(k_origin)
k = k
d, v_new = fwd_prepare_wy_repr(k, v, beta, BT)
o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state)
if checkpoint_level == 1:
d, v_new = None, None
ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state)
ctx.CHECK = CHECK
ctx.chunk_size = BT
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_final_state=None):
q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors
chunk_size = ctx.chunk_size
k = k_origin
# k = _l2_norm_fwd(k_origin)
if d is None:
d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state)
dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size)
dk.add_(dk2)
# dk = _l2_norm_bwd(k_origin, dk)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(d.dtype), None, None, None
def fused_chunk_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
BT: int,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
return o, final_state
def delta_rule_recurrence(q, k, v, beta):
b, h, l, d_k = q.shape
d_v = v.shape[-1]
o = torch.zeros_like(v)
S = torch.zeros(b, h, d_k, d_v).to(v)
q = q * (d_k ** -0.5)
k = torch.nn.functional.normalize(k, p=2, dim=-1)
for i in range(l):
_k = k[:, :, i]
_q = q[:, :, i]
_v = v[:, :, i].clone()
beta_i = beta[:, :, i]
_v = _v - (S.clone() * _k[..., None]).sum(-2)
_v = _v * beta_i[..., None]
S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
return o
if __name__ == "__main__":
import torch.nn.functional as F
seq_len = 128
b = 2
h = 4
q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1)
beta = torch.rand(b, h, seq_len).sigmoid()
q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta))
do = torch.rand_like(v)
o2 = delta_rule_recurrence(q, k, v.clone(), beta)
o2.backward(do, retain_graph=True)
q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad
q.grad = k.grad = v.grad = beta.grad = None
o, _ = fused_chunk_delta_rule(q, k, v, beta, 32)
o.backward(do, retain_graph=True)
q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad
q.grad = k.grad = v.grad = beta.grad = None
print((o - o2).abs().max())
print((q_grad - q_grad2).abs().max())
print((k_grad - k_grad2).abs().max())
print((v_grad - v_grad2).abs().max())
print((beta_grad - beta_grad2).abs().max())

View File

@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
import torch
from einops import rearrange
def delta_rule_recurrence(q, k, v, beta):
b, h, l, d_k = q.shape
d_v = v.shape[-1]
o = torch.zeros_like(v)
S = torch.zeros(b, h, d_k, d_v).to(v)
q = q * (d_k ** -0.5)
for i in range(l):
_k = k[:, :, i]
_q = q[:, :, i]
_v = v[:, :, i].clone()
beta_i = beta[:, :, i]
_v = _v - (S.clone() * _k[..., None]).sum(-2)
_v = _v * beta_i[..., None]
S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
return o
def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
b, h, l, d_k = q.shape
d_v = v.shape[-1]
q = q * (d_k ** -0.5)
v = v * beta[..., None]
k_beta = k * beta[..., None]
assert l % chunk_size == 0
# note that diagonal is masked.
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
for i in range(1, chunk_size):
attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
# u
k_cumsum = attn @ v
# w
k_cumdecay = attn @ k_beta
v = k_cumsum
S = k.new_zeros(b, h, d_k, d_v)
o = torch.zeros_like(v)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
for i in range(0, l // chunk_size):
q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i]
attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
v_prime = k_cumdecay[:, :, i] @ S
v_new = v_i - v_prime
o_inter = q_i @ S
o[:, :, i] = o_inter + attn @ v_new
# chunk state update
S = S + k_i.transpose(-1, -2) @ v_new
return rearrange(o, 'b h n c d -> b h (n c) d')
if __name__ == '__main__':
B = 2
H = 4
L = 256
DK = 128
DV = 128
q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
k = (torch.randn(B, H, L, DK)).cuda()
k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
o = delta_rule_recurrence(q, k, v, beta)
do = torch.randn(B, H, L, DV).cuda()
o.backward(do, retain_graph=True)
q_grad, q.grad = q.grad, None
k_grad, k.grad = k.grad, None
v_grad, v.grad = v.grad, None
beta_grad, beta.grad = beta.grad, None
o2 = delta_rule_chunkwise(q, k, v, beta)
o2.backward(do)
assert torch.allclose(o, o2, atol=1e-4), breakpoint()
assert torch.allclose(q.grad, q_grad, atol=1e-4), breakpoint()
assert torch.allclose(k.grad, k_grad, atol=1e-4), breakpoint()
assert torch.allclose(v.grad, v_grad, atol=1e-4), breakpoint()
assert torch.allclose(beta.grad, beta_grad, atol=1e-4), breakpoint()
print("All passed!")

View File

@@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@triton.jit
def fused_recurrent_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V].
beta, # beta [B, H, L]
o, # output [B, H, L, D_head_V]
initial_state,
final_state, # final hidden state [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_beta = beta + i_bh * T
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
mask_kv = mask_bk[None, :] & mask_bv[:, None]
h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_v_minus = tl.sum(h * _k[None, :], axis=1)
_v -= _v_minus
_beta = tl.load(p_beta).to(tl.float32)
# in-place overwrite
tl.store(p_v, _v.to(p_v.dtype.element_ty), mask=mask_bv)
_v *= _beta
h += _k[None, :] * _v[:, None]
_o = h * _q[None, :]
_o = tl.sum(_o, axis=1)
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
p_q += DK
p_k += DK
p_o += DV
p_v += DV
p_beta += 1
if STORE_FINAL_STATE:
p_final_s = final_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
beta, # beta [B, H, L]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
dbeta, # gradient of beta [B, H, L]
# initial hidden state initialization [B, H, D_head_K, D_head_V]
initial_state,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
mask_bk = i_k * BK + tl.arange(0, BK) < DK
mask_bv = i_v * BV + tl.arange(0, BV) < DV
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_beta = beta + i_bh * T + T - 1
p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
BK + tl.arange(0, BK) + (T - 1) * DK
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
BV + tl.arange(0, BV) + (T - 1) * DV
d_h = tl.zeros([BK, BV], dtype=tl.float32)
for _ in range(T):
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_beta = tl.load(p_beta).to(tl.float32)
d_h += _q[:, None] * _do[None, :]
d_k = tl.sum(d_h * _v[None, :] * _beta, axis=1)
d_v = tl.sum(d_h * _k[:, None], axis=0)
d_beta = tl.sum(d_v * _v)
d_v = d_v * _beta
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))
d_h -= _k[:, None] * d_v[None, :]
p_do -= DV
p_q -= DK
p_k -= DK
p_v -= DV
p_dk -= DK
p_dv -= DV
p_dbeta -= 1
p_beta -= 1
tl.debug_barrier()
h = tl.zeros([BK, BV], dtype=tl.float32)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_beta = beta + i_bh * T
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + DV
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + DK
if USE_INITIAL_STATE:
mask_kv = mask_bk[:, None] & mask_bv[None, :]
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[:, None]) * \
DV + (i_v * BV + tl.arange(0, BV)[None, :])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for i in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
_beta = tl.load(p_beta).to(tl.float32)
_v *= _beta
h += _k[:, None] * _v[None, :]
_d_q = h * _do[None, :]
d_q = tl.sum(_d_q, axis=1) * scale
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
if i < T - 1:
d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32)
d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32)
d_k -= tl.sum(d_v[None, :] * h, axis=1)
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
p_k += DK
p_do += DV
p_v += DV
p_dk += DK
p_dv += DV
p_dq += DK
p_beta += 1
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, q, k, v, beta, initial_state=None, output_final_state=False):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 8)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
assert NK == 1, "NK > 1 is not supported yet"
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
else:
final_state = None
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_fwd_kernel[grid](
q, k, v, beta, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, beta, initial_state)
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, d_final_state=None):
q, k, v, beta, initial_state = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 1
num_warps = 2
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
dbeta = q.new_empty(NV, batch_size, n_heads, seq_len)
fused_recurrent_bwd_kernel[grid](
q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
dbeta = dbeta.sum(0)
return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None
def fused_recurrent_linear_attn_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
if beta is None:
beta = torch.ones_like(q[..., 0])
o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, initial_state, output_final_state)
return o, final_state

View File

@@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
from einops import rearrange
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
from fla.ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2
# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
# o: cumprod
# o2: cumprodsum
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_wy_repr_kernel(
k,
v,
beta,
o,
o2,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
mask_bt = (tl.arange(0, BT) + i_t * BT) < T
mask_bk = tl.arange(0, BK) < K
mask_bv = tl.arange(0, BV) < V
mask_bk = mask_bk[None, :] & mask_bt[:, None]
mask_bv = mask_bv[None, :] & mask_bt[:, None]
# [BT, BK]
b_k = tl.load(p_k, mask=mask_bk, other=0)
# [BT,]
b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)
# [BT, BV]
b_v = tl.load(p_v, mask=mask_bv, other=0)
b_v = (b_v * b_beta[:, None]).to(b_v.dtype)
# [BT, BK]
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
# [BT, BT]
b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
for i in range(BT):
mask = tl.arange(0, BT) == i
b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
b_A = tl.where(mask[:, None], b_a, b_A)
b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
b_A = b_A.to(b_k.dtype)
b_w = tl.dot(b_A, b_kb, allow_tf32=False)
b_u = tl.dot(b_A, b_v, allow_tf32=False)
p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)
p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def bwd_prepare_wy_repr_kernel(
k, v, beta,
o, o2, do, do2,
dk, dv, dbeta,
NT, K, V, T,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
mask_bt = (tl.arange(0, BT) + i_t * BT) < T
mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]
mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]
b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)
b_beta = b_beta.to(tl.float32)
A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]
A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)
b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)
b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)
dA = tl.zeros([BT, BT], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
for i in range(BT-1, -1, -1):
mask = tl.arange(0, BT) == i
attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)
do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)
dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)
b_do = b_do - attn[:, None] * do_[None, :]
b_dv = b_dv - attn[:, None] * dv_[None, :]
tl.debug_barrier()
p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
b_v = tl.load(p_v, mask=mask_bv)
b_dk += b_do * b_beta[:, None]
b_dbeta = tl.sum(b_do * b_k, axis=1)
b_dbeta += tl.sum(b_dv * b_v, axis=1)
b_v = None
p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
b_o = tl.load(p_o, mask=mask_bk)
b_o2 = tl.load(p_o2, mask=mask_bv)
dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)
dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),
allow_tf32=False)
dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)
b_dv *= b_beta[:, None]
p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)
dA = dA * b_beta[:, None]
b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)
b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)
p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)
tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)
def fwd_prepare_wy_repr(k, v, beta, chunk_size):
B, H, T, K, V = *k.shape, v.shape[-1]
v_new = torch.empty_like(v)
o_cumdecay = torch.empty_like(k)
BT = chunk_size
NT = triton.cdiv(T, BT)
BK = triton.next_power_of_2(K)
BV = triton.next_power_of_2(V)
fwd_prepare_wy_repr_kernel[(NT, B*H)](
k, v, beta, o_cumdecay, v_new,
T, K, V, BT, BK, BV
)
return o_cumdecay, v_new
def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):
b, h, l, d_k = do.shape
d_v = v.shape[-1]
BK = triton.next_power_of_2(d_k)
BV = triton.next_power_of_2(d_v)
c = chunk_size
BK = d_k
NT = triton.cdiv(l, c)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
dbeta = torch.zeros_like(beta)
bwd_prepare_wy_repr_kernel[(NT, b*h)](
k, v, beta,
o_cumdecay, v_new, do, do2,
dk, dv, dbeta,
NT, d_k, d_v, l, chunk_size, BK, BV
)
return dk, dv, dbeta
class WYRepresentationPrepration(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, k, v, beta, chunk_size):
o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
ctx.chunk_size = chunk_size
ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new)
return o_cumdecay, v_new
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, do2):
k, v, beta, o_cumdecay, v_new = ctx.saved_tensors
dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size)
return dk, dv, dbeta, None
prepare_wy_repr = WYRepresentationPrepration.apply
def naive(k, v, beta, chunk_size):
l_org = k.shape[2]
l_new = triton.next_power_of_2(l_org)
# pad k, v, beta
k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
# k = torch.nn.functional.normalize(k, dim=-1, p=2)
beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
k_beta = k * beta[..., None]
v = v * beta[..., None]
attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
attn = attn * beta[..., None]
x = attn @ v
o = torch.zeros_like(k)
o2 = torch.zeros_like(v)
o[..., 0, :] = k_beta[..., 0, :].clone()
o2[..., 0, :] = x[..., 0, :].clone()
for i in range(1, chunk_size):
o_i = (o[..., :i, :]).clone()
o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
o2_i = (o2[..., :i, :]).clone()
o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
if __name__ == "__main__":
torch.set_default_dtype(torch.bfloat16)
seq_len = 2048
b = 4
h = 8
k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2)
v = torch.randn(b, h, seq_len, 256)
beta = torch.rand(b, h, seq_len).sigmoid()
require_grad = True
k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
do = torch.rand_like(k)
do2 = torch.rand_like(v)
print("Start warmup.")
o1, o2 = prepare_wy_repr(k, v, beta, 32)
# (o1 * do + o2 * do2).sum().backward()
o3, o4 = prepare_wy_repr2(k, v, beta, 32)
# (o1 * do + o2 * do2).sum().backward()
print((o1 - o3).abs().max())
print((o2 - o4).abs().max())
for i in range(30):
o1, o2 = prepare_wy_repr(k, v, beta, 32)
(o1 * do + o2 * do2).sum().backward()
o1, o2 = prepare_wy_repr2(k, v, beta, 32)
(o1 * do + o2 * do2).sum().backward()
print("Done warmup.")
import time
torch.cuda.synchronize()
start = time.time()
for i in range(200):
o1, o2 = prepare_wy_repr(k, v, beta, 64)
(o1 * do + o2 * do2).sum().backward()
torch.cuda.synchronize()
print(time.time() - start)
torch.cuda.synchronize()
start = time.time()
for i in range(200):
o1, o2 = prepare_wy_repr2(k, v, beta, 64)
(o1 * do + o2 * do2).sum().backward()
torch.cuda.synchronize()
print(time.time() - start)

View File

@@ -0,0 +1,401 @@
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
from einops import rearrange
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
# o: cumprod
# o2: cumprodsum
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_wy_repr_kernel(
k,
v,
beta,
w,
u,
A,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
for i in range(1, BT):
mask = tl.arange(0, BT) == i
b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
b_A = tl.where(mask[:, None], b_a, b_A)
b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))
b_A = b_A.to(k.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
b_w = tl.dot(b_A, b_kb, allow_tf32=False)
p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_recompute_w_u_kernel(
k,
v,
beta,
w,
u,
A,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
b_w = tl.dot(b_A, b_kb, allow_tf32=False)
p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["BT", "BK", "BV"],
)
@triton.jit
def bwd_prepare_wy_repr_kernel(
k, v, beta, A,
dw, du,
dk, dv, dbeta,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
T,
K,
V,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
b_dbeta = tl.zeros([BT], dtype=tl.float32)
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
b_du = tl.load(p_du, boundary_check=(0, 1))
b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)
b_dv = b_dv_beta * b_beta[:, None]
b_dbeta += tl.sum(b_dv_beta * b_v, 1)
# store
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
tl.debug_barrier()
b_A2 = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
b_dw = tl.load(p_dw, boundary_check=(0, 1))
b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
b_A2 += tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)
b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
b_dk = b_dk_beta * b_beta[:, None]
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
# store
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
b_A -= (tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :])
b_A2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_A2, 0)
b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
tl.debug_barrier()
for i in range(BT-1, 0, -1):
mask = tl.arange(0, BT) == i
b_da = tl.sum(tl.where(mask[:, None], b_dA, 0), 0)
b_a = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
b_da2 = b_da + tl.sum(b_da[None, :] * b_A, 1)
b_dA = tl.where(mask[:, None], b_da2, b_dA)
b_dA += b_da[None, :] * b_a[:, None]
b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
tl.debug_barrier()
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dk = tl.load(p_dk, boundary_check=(0, 1))
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
b_dk += b_dk_beta * b_beta[:, None]
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty),boundary_check=(0,))
def fwd_prepare_wy_repr(k, v, beta, BT):
B, H, T, K, V = *k.shape, v.shape[-1]
u = torch.empty_like(v)
w = torch.empty_like(k)
NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)
fwd_prepare_wy_repr_kernel[(NT, B*H)](
k, v, beta, w, u, A,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
T, K, V, BT, BK, BV
)
return w, u, A
def fwd_recompute_w_u(k, v, beta, A, BT):
B, H, T, K, V = *k.shape, v.shape[-1]
u = torch.empty_like(v)
w = torch.empty_like(k)
NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
fwd_recompute_w_u_kernel[(NT, B*H)](
k, v, beta, w, u, A,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
T, K, V, BT, BK, BV
)
return w, u
def bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):
B, H, T, K, V = *k.shape, v.shape[-1]
NT = triton.cdiv(T, BT)
BK = min(triton.next_power_of_2(K), 64)
BV = min(triton.next_power_of_2(V), 64)
NT = triton.cdiv(T, BT)
dk = torch.empty_like(k)
dv = torch.empty_like(v).contiguous()
dbeta = torch.zeros_like(beta)
bwd_prepare_wy_repr_kernel[(NT, B*H)](
k, v, beta, A,
dw, du,
dk, dv, dbeta,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
T, K, V, BT, BK, BV
)
return dk, dv, dbeta
class WYRepresentationPrepration(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, k, v, beta, chunk_size):
ctx.BT = chunk_size
w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT)
ctx.save_for_backward(k, v, beta, A)
return w, u
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, dw, du):
k, v, beta, A = ctx.saved_tensors
BT = ctx.BT
dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT)
return dk, dv, dbeta, None
prepare_wy_repr = WYRepresentationPrepration.apply
def naive(k, v, beta, chunk_size):
l_org = k.shape[2]
l_new = triton.next_power_of_2(l_org)
# pad k, v, beta
k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
# k = torch.nn.functional.normalize(k, dim=-1, p=2)
beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
k_beta = k * beta[..., None]
v = v * beta[..., None]
attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
attn = attn * beta[..., None]
x = attn @ v
o = torch.zeros_like(k)
o2 = torch.zeros_like(v)
o[..., 0, :] = k_beta[..., 0, :].clone()
o2[..., 0, :] = x[..., 0, :].clone()
for i in range(1, chunk_size):
o_i = (o[..., :i, :]).clone()
o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
o2_i = (o2[..., :i, :]).clone()
o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
if __name__ == "__main__":
torch.set_default_dtype(torch.float32)
seq_len = 1024
b = 4
h = 4
k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)
v = torch.randn(b, h, seq_len, 128)
beta = torch.rand(b, h, seq_len).sigmoid()
# beta = torch.ones(b, h, seq_len)
require_grad = True
k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
do = torch.rand_like(k)
do2 = torch.rand_like(v)
o1, o2 = naive(k.clone(), v.clone(), beta.clone(), 64)
if require_grad:
o1.backward(do, retain_graph=True)
o2.backward(do2, retain_graph=True)
k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad
k.grad = v.grad = beta.grad = None
o3, o4 = prepare_wy_repr(k.clone(), v.clone(), beta.clone())
print((o1-o3).abs().max())
print((o2-o4).abs().max())
if require_grad:
o3.backward(do, retain_graph=True)
o4.backward(do2, retain_graph=True)
k_grad, v_grad, beta_grad = k.grad, v.grad, beta.grad
print((k_grad2-k_grad).abs().max())
print((v_grad2-v_grad).abs().max())
print((beta_grad2-beta_grad).abs().max())
breakpoint()

View File

@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_gla
from .chunk_fuse import fused_chunk_gla
from .recurrent_fuse import fused_recurrent_gla
__all__ = [
'chunk_gla',
'fused_chunk_gla',
'fused_recurrent_gla'
]

734
finetune/lora/v6/fla/ops/gla/chunk.py vendored Normal file
View File

@@ -0,0 +1,734 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from fla.ops.utils import chunk_reversed_cumsum_fwd
from fla.utils import contiguous
@triton.autotune(
configs=[
triton.Config({'BS': 16}, num_warps=2),
triton.Config({'BS': 16}, num_warps=4),
triton.Config({'BS': 16}, num_warps=8),
triton.Config({'BS': 32}, num_warps=2),
triton.Config({'BS': 32}, num_warps=4),
triton.Config({'BS': 32}, num_warps=8),
triton.Config({'BS': 64}, num_warps=2),
triton.Config({'BS': 64}, num_warps=4),
triton.Config({'BS': 64}, num_warps=8),
],
key=['S']
)
@triton.jit
def chunk_gla_fwd_kernel_cum(
s,
o,
s_s_h,
s_s_t,
s_s_d,
T: tl.constexpr,
S: tl.constexpr,
BT: tl.constexpr,
BS: tl.constexpr
):
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
# [BT, BS]
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
b_o = tl.dot(m_s, b_s, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_gla_fwd_kernel_h(
k,
v,
g,
h,
h0,
ht,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
for i_t in range(NT):
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BK, BT]
b_g = tl.load(p_g, boundary_check=(0, 1))
if i_t < NT - 1:
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
else:
b_gn = tl.min(b_g, axis=1)
b_h *= tl.exp(b_gn)[:, None]
b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
b_h += tl.dot(b_k, b_v, allow_tf32=False)
if STORE_FINAL_STATE:
p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_gla_fwd_kernel_intra(
q,
k,
g,
A,
s_k_h,
s_k_t,
s_k_d,
scale,
T: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
n_bh = tl.num_programs(2)
if i_i > i_j:
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
# [BC, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))
b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype)
# [BK, BC]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)
# [BC, BC]
b_A = tl.dot(b_qg, b_kg, allow_tf32=False)
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
elif i_i == i_j:
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
# [BC, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))
o_i = tl.arange(0, BC)
o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
for j in range(0, BC):
# [BK,]
b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
# [BC,]
b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1)
b_A = tl.where(o_i >= j, b_A, 0.)
tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)
p_k = tl.advance(p_k, (K,))
p_gk = tl.advance(p_gk, (K,))
@triton.jit
def chunk_gla_fwd_kernel_inter(
q,
v,
g,
h,
o,
A,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
scale,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
b_o = tl.zeros([BT, BV], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BK]
b_g = tl.load(p_g, boundary_check=(0, 1))
# [BT, BK]
b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
# works but dkw, owing to divine benevolence
# [BT, BV]
if i_k >= 0:
b_o += tl.dot(b_qg, b_h, allow_tf32=False)
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BT]
b_A = tl.load(p_A, boundary_check=(0, 1))
b_o += tl.dot(b_A, b_v, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_gla_bwd_kernel_dh(
q,
g,
do,
dh,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
scale,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BV]
b_do = tl.load(p_do, boundary_check=(0, 1))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
# [BK, BV]
b_dh *= tl.exp(b_gn)[:, None]
# [BK, BT]
b_g = tl.load(p_g, boundary_check=(0, 1))
b_q = (b_q * tl.exp(b_g)).to(b_q.dtype)
# [BK, BV]
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
@triton.jit
def chunk_gla_bwd_kernel_inter(
k,
v,
h,
g,
A,
do,
dh,
dq,
dk,
dv,
dA,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
scale,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
n_bh = tl.num_programs(2)
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
# [BT, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk)
b_k = (b_k * b_gn).to(b_k.dtype)
# [BT, BT]
b_A = tl.load(p_A, boundary_check=(0, 1))
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BT, BV]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BV]
b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
if i_k == 0:
b_dv += tl.dot(b_A, b_do, allow_tf32=False)
b_do = (b_do * scale).to(b_do.dtype)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# [BT, BT]
b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False)
# [BT, BK]
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
b_dq = b_dq * tl.exp(b_gk)
b_dk = b_dk * b_gn
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
# [BT, BT]
b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
if i_k == 0:
tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_gla_bwd_kernel_intra(
q,
k,
g,
dA,
dq,
dk,
dg,
s_k_h,
s_k_t,
s_k_d,
T: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_t, i_i = i_c // NC, i_c % NC
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
# [BC, BK]
b_g = tl.load(p_g, boundary_check=(0, 1))
b_dq = tl.zeros([BC, BK], dtype=tl.float32)
for i_j in range(0, i_i):
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
# [BC, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
# [BC, BC]
b_dA = tl.load(p_dA, boundary_check=(0, 1))
# [BC, BK]
b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
b_dq *= tl.exp(b_g - b_gn[None, :])
o_i = tl.arange(0, BC)
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
for j in range(0, BC):
p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
p_gkj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
# [BC,]
b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
# [BK,]
b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
b_gkj = tl.load(p_gkj, boundary_check=(0,)).to(tl.float32)
# [BC, BK]
m_i = o_i[:, None] >= j
# [BC, BK]
b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_g - b_gkj[None, :]), 0.)
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.debug_barrier()
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
# [BC, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_dk = tl.zeros([BC, BK], dtype=tl.float32)
for i_j in range(i_i + 1, NC):
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
# [BC, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))
b_qg = (b_q * tl.exp(b_g - b_gn[None, :])).to(b_q.dtype)
# [BC, BC]
b_dA = tl.load(p_dA, boundary_check=(0, 1))
# [BC, BK]
b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False)
b_dk *= tl.exp(b_gn[None, :] - b_gk)
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
for j in range(0, BC):
p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
p_gqj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
# [BC,]
b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
# [BK,]
b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32)
# [BC, BK]
m_i = o_i[:, None] <= j
b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.)
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_dg = tl.make_block_ptr(dg + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1))
b_dg = b_q * b_dq - b_k * b_dk
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
class ChunkGLAFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level):
B, H, T, K, V = *q.shape, v.shape[-1]
BT, BC = 64, 16
BK = min(64, triton.next_power_of_2(K))
BV = min(64, triton.next_power_of_2(V))
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
num_warps = 4 if BK == 64 else 2
num_stages = 1
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
h = q.new_empty(B, H, NT * K, V)
grid = (NV, NK, B * H)
chunk_gla_fwd_kernel_h[grid](
k, v, g, h, h0, ht,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=h0 is not None,
STORE_FINAL_STATE=ht is not None,
num_warps=num_warps,
num_stages=num_stages
)
return h
final_state = None
if output_final_state:
final_state = q.new_empty(B, H, K, V, dtype=torch.float)
g_org, g = g, torch.empty_like(g, dtype=torch.float)
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
# keep cummulative normalizer in fp32
# this kernel is equivalent to
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
chunk_gla_fwd_kernel_cum[grid](
g_org, g,
g.stride(1), g.stride(2), g.stride(3),
T=T, S=K, BT=BT
)
h = fwd_inner(
q=q, k=k, v=v, g=g,
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
h0=initial_state if initial_state is not None else None,
ht=final_state if final_state is not None else None
)
A = q.new_zeros(NK, B, H, T, BT)
grid = (NK, NT * NC * NC, B * H)
chunk_gla_fwd_kernel_intra[grid](
q, k, g, A,
k.stride(1), k.stride(2), k.stride(3),
scale,
T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
num_warps=num_warps,
num_stages=num_stages
)
A = A.sum(0, dtype=A.dtype)
o = torch.empty_like(v)
grid = (NV, NT, B * H)
chunk_gla_fwd_kernel_inter[grid](
q, v, g, h, o, A,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
scale,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
if checkpoint_level >= 1:
del g
g = g_org
if checkpoint_level > 1:
del h
h, initial_state = None, None
ctx.save_for_backward(q, k, v, g, h, initial_state, A)
ctx.BT = BT
ctx.scale = scale
ctx.checkpoint_level = checkpoint_level
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, dht=None):
q, k, v, g, h, initial_state, A = ctx.saved_tensors
B, H, T, K, V = *q.shape, v.shape[-1]
BT, BC = ctx.BT, 16
BK = min(64, triton.next_power_of_2(K))
BV = min(64, triton.next_power_of_2(V))
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
NK = triton.cdiv(K, BK)
num_warps = 4 if BK == 64 else 2
num_stages = 1
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
h = q.new_empty(B, H, NT * K, V)
grid = (NV, NK, B * H)
chunk_gla_fwd_kernel_h[grid](
k, v, g, h, h0, ht,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=h0 is not None,
STORE_FINAL_STATE=ht is not None,
num_warps=num_warps,
num_stages=num_stages
)
return h
def bwd_inner(q, g, do, B, H, T, K, V, BT, BK, BV, NT, scale):
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
dh = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_gla_bwd_kernel_dh[grid](
q, g, do, dh,
q.stride(1), q.stride(2), q.stride(3),
do.stride(1), do.stride(2), do.stride(3),
dh.stride(1), dh.stride(2), dh.stride(3),
scale,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
return dh
if ctx.checkpoint_level >= 1:
# save the original g and compute its fp32 cumsum during the backward pass for memory consideration
g_org, g = g, torch.zeros_like(g, dtype=torch.float)
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
# keep cummulative normalizer in fp32
# this kernel is equivalent to
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
chunk_gla_fwd_kernel_cum[grid](
g_org, g,
g.stride(1), g.stride(2), g.stride(3),
T=T, S=K, BT=BT
)
# rerun the forward pass to get h if checkpoint_level >= 1
if ctx.checkpoint_level > 1:
h = fwd_inner(
q=q, k=k, v=v, g=g,
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
h0=initial_state if initial_state is not None else None,
ht=None
)
scale = ctx.scale
dh = bwd_inner(
q, g, do,
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
scale=scale
)
dq = torch.empty_like(q, dtype=torch.float)
dk = torch.empty_like(k, dtype=torch.float)
dg = torch.empty_like(k, dtype=torch.float)
dv = v.new_empty(NK, *v.shape)
dA = q.new_zeros(B, H, T, BT)
grid = (NK, NT, B * H)
chunk_gla_bwd_kernel_inter[grid](
k, v, h, g, A, do, dh, dq, dk, dv, dA,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
scale,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
dv = dv.sum(0, dtype=dv.dtype)
grid = (NK, NT * NC, B * H)
chunk_gla_bwd_kernel_intra[grid](
q, k, g, dA, dq, dk, dg,
k.stride(1), k.stride(2), k.stride(3),
T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
num_warps=num_warps,
num_stages=num_stages
)
dq = dq.to(q.dtype)
dk = dk.to(q.dtype)
# reversed cumsum, equivalent to:
#
# def reversed_cumsum(x, dim=-1):
# c = x.cumsum(dim)
# return x + c.index_select(dim, x.new_tensor([c.shape[dim]-1], dtype=torch.long)) - c
dg = chunk_reversed_cumsum_fwd(dg).to(k.dtype)
return dq, dk, dv, dg, None, None, None, None
def chunk_gla(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
scale: Optional[int] = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
checkpoint_level: Optional[int] = 2
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
q (torch.Tensor):
queries of shape `(B, H, T, K)`
k (torch.Tensor):
keys of shape `(B, H, T, K)`
v (torch.Tensor):
values of shape `(B, H, T, V)`
g (torch.Tensor):
Forget gates of shape `(B, H, T, K)` applied to keys.
scale (Optional[int]):
Scale factor for the GLA attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `(B, H, K, V)`. Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
checkpoint_level (Optional[int]):
Checkpointing level; higher values will save more memories and do more recomputations during backward.
Default: `0`:
- Level `0`: no memory saved, no recomputation.
- Level `1`: recompute the fp32 cumulative values during backward.
- Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
"""
assert checkpoint_level in [0, 1, 2]
if scale is None:
scale = q.shape[-1] ** -0.5
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level)
return o, final_state

View File

@@ -0,0 +1,548 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Songlin Yang
# Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635
# on-the-fly computation without materializing hidden statets into HBMs
from typing import Tuple
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
from packaging import version
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.ops.gla.chunk_util import (bwd_decay_global_cumsum, fwd_decay_cumsum,
prepare_qg_kg)
from fla.utils import contiguous
inv_ln2 = 1.44269504
@triton.jit
def fused_chunk_gla_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_K]
v, # value [B, H, L, D_head_V]
g, # cumulative sum of log decay [B, H, L, D_head_K]
o, # output [B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
b_h = tl.zeros([BK, BV], dtype=tl.float32)
# make block pointers
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
for i in range(0, tl.cdiv(T, BT)):
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)
if CHECK and i == 0:
b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
else:
b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
p_q = tl.advance(p_q, (BT, 0))
p_k = tl.advance(p_k, (0, BT))
p_v = tl.advance(p_v, (BT, 0))
p_o = tl.advance(p_o, (BT, 0))
p_db += BT * DK
if STORE_FINAL_STATE:
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_chunk_gla_bwd_kernel(
q, k, v, g,
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BV, BK]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
for i in range(0, tl.cdiv(T, BT)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# b_g = tl.load(p_g, boundary_check=(0, 1)) * inv_ln2
d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)
# [DV, BT]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, DV]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [DV, DK]
if CHECK and i == 0:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
else:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
b_dq *= scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
# sync threads
b_h = None
tl.debug_barrier()
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
# cum = tl.zeros([BK], dtype=tl.float32)
for i in range(1, tl.cdiv(T, BT) + 1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
# [DK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, DV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32)
# inter-chunk
# [DK, DV]
if CHECK and i == 1:
b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
else:
b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def fwd_inner_chunk(
q, k, g, A,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
DK: tl.constexpr, # D_head_K
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
o_i = tl.arange(0, BT)
p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
for i in range(BT):
_q = tl.load(p_q, mask=mask, other=0) * scale
gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
s = _q[None, :] * b_k * tl.math.exp2(gq[None, :] - b_g)
score = tl.sum(s, axis=1)
score = tl.where(o_i <= i, score, 0)
tl.store(p_A, score.to(p_A.dtype.element_ty))
p_q += DK
p_gq += DK
p_A += BT
@triton.jit
def bwd_inner_chunk(
q,
k,
g,
dA,
dq,
dk,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
DK: tl.constexpr, # D_head_K
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
o_i = tl.arange(0, BT)
p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
p_dq = dq + (i_bh) * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
for i in range(BT):
_q = tl.load(p_q, mask=mask, other=0)
gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
score = tl.math.exp2(gq[None, :] - b_g)
score = tl.where(o_i[:, None] <= i, score, 0)
_dA = tl.load(p_dA)
_dA = tl.where(o_i <= i, _dA, 0)
b_dk += (_dA[:, None] * score * _q[None, :])
b_dq = tl.sum(_dA[:, None] * score * b_k, axis=0)
tl.store(p_dq, b_dq, mask=mask)
p_q += DK
p_dq += DK
p_gq += DK
p_dA += BT
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1))
class FusedChunkGLAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
ctx.g_dtype = g.dtype
g_original = g
# cumulative decay should be in float32, otherwise the err will be accumulated and amplified.
g = torch.empty_like(g, dtype=torch.float32)
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
ctx.scale = scale
# inter-chunk
BT = 16 # chunk_size
BK, BV = min(d_head_qk, 64), min(d_head_v, 64)
num_stages = 1
num_warps = 2
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
q_g = torch.empty_like(q)
k_g = torch.empty_like(k)
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
fwd_decay_cumsum[grid](
g_original,
g,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
)
prepare_qg_kg[grid](
q, k, g, q_g, k_g,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float, requires_grad=False)
else:
final_state = None
# the bug still exists even for Triton 2.2 on H100 GPUs
# so we always enable initial checks
CHECK = True
if version.parse(triton.__version__) < version.parse('2.2.0'):
import warnings
warnings.warn(
"Triton<2.2.0 detected for running this kernel, "
"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
"that lead to significant precision loss. "
"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
)
CHECK = True
grid = (NV, NK, batch_size * n_heads)
fused_chunk_gla_fwd_kernel[grid](
q_g, k_g, v, g, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
CHECK=CHECK,
num_warps=num_warps,
num_stages=num_stages
)
o = o.sum(0)
# intra-chunk
chunk_size = 16
num_chunk = seq_len // chunk_size
v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
BK = min(d_head_qk, 64)
NK = triton.cdiv(d_head_qk, BK)
A = q.new_empty(NK, batch_size, n_heads, triton.cdiv(seq_len, BT), BT, BT)
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
fwd_inner_chunk[grid](
q, k, g, A,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale, BT=BT, BK=BK, DK=d_head_qk, num_stages=3,
num_warps=4
)
A = A.sum(0)
o2 = A @ v2
o2 = rearrange(o2, 'b h n c d -> b h (n c) d')
# combine inner and inter
o.add_(o2)
ctx.save_for_backward(q, k, v, g_original, A, initial_state)
ctx.CHECK = CHECK
return o.to(v), final_state
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, d_final_state=None):
q, k, v, g_origin, A, initial_state = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = ctx.scale
# recomputation
# inter-chunk
BT = 16 # chunk_size
g = torch.empty_like(g_origin, dtype=torch.float32)
BK, BV = min(d_head_qk, 64), min(d_head_v, 64)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
q_g = torch.empty_like(q)
k_g = torch.empty_like(k)
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
fwd_decay_cumsum[grid](
g_origin,
g,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
)
prepare_qg_kg[grid](
q, k, g, q_g, k_g,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
)
# inter-chunk
BT = 16
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 2
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_gla_bwd_kernel[grid](
q_g, k_g, v, g, do, dq, dk, dv, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
# clamp_min=-3,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
CHECK=ctx.CHECK,
num_warps=num_warps,
num_stages=num_stages,
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
# intra chunk
num_chunk = seq_len // BT
v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk)
dA2 = (do2 @ v2.transpose(-2, -1)) * scale
dv2 = A.transpose(-1, -2) @ do2
dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk)
BK = min(triton.next_power_of_2(d_head_qk), 16)
NK = triton.cdiv(d_head_qk, BK)
dk2 = torch.empty_like(k)
dq2 = torch.empty_like(q)
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
bwd_inner_chunk[grid](
q, k, g,
dA2, dq2, dk2,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, BK=BK,
num_warps=1,
num_stages=3
)
BK = min(triton.next_power_of_2(d_head_qk), 32)
NK = triton.cdiv(d_head_qk, BK)
dg = torch.empty_like(g, dtype=torch.float32)
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
bwd_decay_global_cumsum[grid](
dq2, dq, dk2, dk, q, k, g, dg,
q.stride(1), q.stride(2), q.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, BK=BK,
num_warps=1,
num_stages=1
)
dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)
def rev_cumsum_exclusive(x):
cumsum_x = x.cumsum(-2)
rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x
return rev_cumsum_x
rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])
dg.add_(rev_cumsum_dg.unsqueeze(-2))
dv.add_(dv2)
dg = rearrange(dg, 'b h n c d -> b h (n c) d')
return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None
def pad(x, chunk_size=16):
seq_len = x.shape[-2]
padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size
if x.shape[-2] % chunk_size != 0:
x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len))
return x
def ceildiv(a, b):
return -(a // -b)
def fused_chunk_gla(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
scale: int = -1,
initial_state: torch.Tensor = None,
output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if scale == -1:
scale = q.shape[-1] ** -0.5
if initial_state is not None:
initial_state = initial_state.detach()
seq_len = q.shape[-2]
q, k, v, g = map(lambda x: pad(x), [q, k, v, g])
o, final_state = FusedChunkGLAFunction.apply(
q, k, v, g, scale, initial_state, output_final_state)
o = o[..., :seq_len, :]
return o, final_state

View File

@@ -0,0 +1,138 @@
import triton
import triton.language as tl
inv_ln2 = 1.44269504
@triton.jit
def fwd_decay_cumsum(
g,
g_o,
s_qk_h,
s_qk_t,
s_qk_d,
B,
H,
T,
scale,
BT: tl.constexpr,
BK: tl.constexpr,
DK: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
cum_decay = tl.zeros([BK], dtype=tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
for i in range(BT):
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
cum_decay += _g * inv_ln2
tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)
p_g += DK
p_go += DK
@triton.jit
def prepare_qg_kg(
q,
k,
g,
qg,
kg,
s_qk_h,
s_qk_t,
s_qk_d,
B,
H,
T,
scale,
BT: tl.constexpr,
BK: tl.constexpr,
DK: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
mask = (i_k * BK + tl.arange(0, BK)) < DK
last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))
for i in range(BT):
_q = tl.load(p_q, mask=mask, other=0)
_k = tl.load(p_k, mask=mask, other=0)
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
_q *= tl.math.exp2(_g) * scale
_k *= tl.math.exp2(last_decay - _g)
tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)
tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)
p_q += DK
p_g += DK
p_k += DK
p_kg += DK
p_qg += DK
@triton.jit
def bwd_decay_global_cumsum(
dq_inner,
dq_inter,
dk_inner,
dk_inter,
q, k, g, dg,
s_qk_h,
s_qk_t,
s_qk_d,
B,
H,
T,
scale,
BT: tl.constexpr,
BK: tl.constexpr,
DK: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
cum_grad_dg = tl.zeros([BK], dtype=tl.float32)
mask = (i_k * BK + tl.arange(0, BK)) < DK
last_g = tl.zeros([BK], dtype=tl.float32)
for j in range(BT-1, -1, -1):
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
if j == (BT-1):
last_g = _g
_dq1 = tl.load(p_dq_inner, mask=mask, other=0)
_dq2 = tl.load(p_dq_inter, mask=mask, other=0)
_dq2 *= tl.math.exp2(_g)
_dq = _dq1 + _dq2
tl.store(p_dq_inter, _dq, mask=mask)
_dk1 = tl.load(p_dk_inner, mask=mask, other=0)
_dk2 = tl.load(p_dk_inter, mask=mask, other=0)
_dk2 *= tl.math.exp2(last_g - _g)
_dk = _dk1 + _dk2
tl.store(p_dk_inter, _dk, mask=mask)
_q = tl.load(p_q, mask=mask, other=0)
_k = tl.load(p_k, mask=mask, other=0)
_dg = _dq * _q - _dk * _k
cum_grad_dg += _dg
tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)
p_g -= DK
p_k -= DK
p_q -= DK
p_dq_inner -= DK
p_dk_inner -= DK
p_dq_inter -= DK
p_dk_inter -= DK
p_dg -= DK

116
finetune/lora/v6/fla/ops/gla/naive.py vendored Normal file
View File

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
from fla.ops.gla.recurrent_fuse import fused_recurrent_gla
def ceildiv(a, b):
return -(a // -b)
def naive_recurrent_gla(
q,
k,
v,
gk,
initial_state=None,
output_final_state=False,
causal=True
):
orig_dtype = q.dtype
q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk))
batch_size, n_heads, seq_len, d_head_k = q.shape
_, _, _, d_head_v = v.shape
h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
o = torch.zeros_like(v)
scale = d_head_k ** -0.5
if initial_state is not None:
h += initial_state
for i in range(seq_len):
q_i = q[:, :, i, :] * scale
k_i = k[:, :, i]
v_i = v[:, :, i, :]
gk_i = gk[:, :, i].exp()
kv_i = k_i[..., None] * v_i[..., None, :]
h = h * gk_i[..., None] + kv_i
o_i = (q_i[..., None] * h).sum(-2)
o[:, :, i] = o_i
if causal:
return o.to(orig_dtype), h
else:
o_reverse = torch.zeros_like(v)
h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
for i in range(seq_len-1, -1, -1):
q_i = q[:, :, i, :] * scale
k_i = k[:, :, i]
v_i = v[:, :, i, :]
gk_i = gk[:, :, i].exp()
kv_i = k_i[..., None] * v_i[..., None, :]
h = h * gk_i[..., None] + kv_i
o_i = (q_i[..., None] * h).sum(-2)
o_reverse[:, :, i] = o_i
return o, o_reverse
if __name__ == "__main__":
B = 4
H = 4
L = 512
D = 128
dtype = torch.float32
q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
g = F.logsigmoid(torch.rand(B, H, L, D)).cuda(
).clamp_min(-1).to(torch.float32).requires_grad_(True)
do = torch.rand_like(v).cuda()
do2 = torch.rand_like(v).cuda()
intial_state = torch.rand(B, H, D, D).cuda()
ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False)
ref.backward(do, retain_graph=True)
ref_rev.backward(do2, retain_graph=True)
ref_dq, q.grad = q.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dv, v.grad = v.grad.clone(), None
ref_dg, g.grad = g.grad.clone(), None
tri, tri_rev = fused_recurrent_gla(
q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, causal=False)
tri.backward(do, retain_graph=True)
tri_rev.backward(do2, retain_graph=True)
tri_dq, q.grad = q.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dv, v.grad = v.grad.clone(), None
tri_dg, g.grad = g.grad.clone(), None
assert ref.allclose(tri, 0, 1e-5), breakpoint()
assert ref_rev.allclose(tri_rev, 0, 1e-5), breakpoint()
assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
# tri = fused_chunk_gla(q, k, v, g)
# tri.backward(do, retain_graph=True)
# tri_dq, q.grad = q.grad.clone(), None
# tri_dk, k.grad = k.grad.clone(), None
# tri_dv, v.grad = v.grad.clone(), None
# tri_dg, g.grad = g.grad.clone(), None
# assert ref.allclose(tri, 0, 1e-5), breakpoint()
# assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
# assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
# assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
# assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
# breakpoint()
print("Pass")

View File

@@ -0,0 +1,404 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@triton.jit
def fused_recurrent_gla_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_K]
v, # value [B, H, L, D_head_V]
gk, # log gate [B, H, L, D_head_K]
gv, # log gate [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
# initial hidden state initialization [B, H, D_head_K, D_head_V]
initial_state,
final_state, # final hidden state [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
USE_GK: tl.constexpr, # whether to use gk
USE_GV: tl.constexpr, # whether to use gv
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
p_k = k + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
p_v = v + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
h = tl.zeros([BV, BK], dtype=tl.float32)
mask_kv = mask_bk[None, :] & mask_bv[:, None]
if USE_INITIAL_STATE:
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
if USE_GK:
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
h = h * _gk[None, :]
if USE_GV:
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
h = h * _gv[:, None]
h += _k[None, :] * _v[:, None]
_o = h * _q[None, :]
_o = tl.sum(_o, axis=1)
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
p_q += -DK if REVERSE else DK
p_k += -DK if REVERSE else DK
p_o += -DV if REVERSE else DV
p_v += -DV if REVERSE else DV
if USE_GK:
p_gk += -DK if REVERSE else DK
if USE_GV:
p_gv += -DV if REVERSE else DV
if STORE_FINAL_STATE:
p_final_s = final_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_gla_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
gk, # log gate [B, H, L, D_head_K] \alpha
gv, # log gate [B, H, L, D_head_V] \bete
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
# initial hidden state initialization [B, H, D_head_K, D_head_V]
initial_state,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
USE_GK: tl.constexpr, # whether to use gk
USE_GV: tl.constexpr, # whether to use gv
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
p_k = k + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
p_v = v + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
p_do = do + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
mask_bk = i_k * BK + tl.arange(0, BK) < DK
mask_bv = i_v * BV + tl.arange(0, BV) < DV
mask_kv = mask_bk[:, None] & mask_bv[None, :]
h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[:, None]) * \
DV + (i_v * BV + tl.arange(0, BV)[None, :])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for i in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
if USE_GK:
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
h = h * _gk[:, None]
if USE_GV:
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
h = h * _gv[None, :]
h += _k[:, None] * _v[None, :]
_d_q = h * _do[None, :]
d_q = tl.sum(_d_q, axis=1) * scale
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
p_k += -DK if REVERSE else DK
p_v += -DV if REVERSE else DV
p_q += -DK if REVERSE else DK
p_do += -DV if REVERSE else DV
p_dq += -DK if REVERSE else DK
if USE_GK:
p_gk += -DK if REVERSE else DK
if USE_GV:
p_gv += -DV if REVERSE else DV
# sync threads
tl.debug_barrier()
p_q = q + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
p_k = k + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
p_do = do + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
p_v = v + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
if USE_GK:
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
if USE_GV:
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
d_h = tl.zeros([BK, BV], dtype=tl.float32)
for _ in range(T):
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
d_h += _q[:, None] * _do[None, :]
d_k = tl.sum(d_h * _v[None, :], axis=1)
d_v = tl.sum(d_h * _k[:, None], axis=0)
if USE_GK:
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
d_h *= _gk[:, None]
if USE_GV:
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
d_h *= _gv[None, :]
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
p_do += DV if REVERSE else -DV
p_q += DK if REVERSE else -DK
p_k += DK if REVERSE else -DK
p_v += DV if REVERSE else -DV
p_dk += DK if REVERSE else -DK
p_dv += DV if REVERSE else -DV
if USE_GK:
p_gk += DK if REVERSE else -DK
if USE_GV:
p_gv += DV if REVERSE else -DV
class FusedRecurrentGLAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
# default scale
if scale is None:
scale = d_head_qk ** -0.5
if gk is not None:
gk = gk.float().exp()
if gv is not None:
gv = gv.float().exp()
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
o = q.new_empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=torch.float32)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
else:
final_state = None
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_gla_fwd_kernel[grid](
q, k, v, gk, gv, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None,
USE_GK=gk is not None,
USE_GV=gv is not None,
REVERSE=reverse,
num_warps=num_warps,
num_stages=num_stages
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, gk, gv, initial_state, o)
ctx.scale = scale
ctx.reverse = reverse
# we do not need the gradient of the final state from the next chunk
# similiar to Trunctated BPTT
if final_state is not None:
final_state = final_state.detach()
return o.to(q.dtype), final_state
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, d_final_state=None):
q, k, v, gk, gv, initial_state, o = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = ctx.scale
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
dq = q.new_empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=torch.float32)
dk = q.new_empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=torch.float32)
dv = q.new_empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=torch.float32)
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_gla_bwd_kernel[grid](
q, k, v, gk, gv, do, dq, dk, dv, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None,
REVERSE=ctx.reverse,
USE_GK=gk is not None,
USE_GV=gv is not None
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
if gk is not None:
_dgk = dq * q.float() - dk * k.float()
if ctx.reverse:
dgk = _dgk.cumsum(-2)
else:
_dgk_cumsum = _dgk.cumsum(-2)
dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum
else:
dgk = None
if gv is not None:
_dgv = do.float() * o.float() - dv * v.float()
if ctx.reverse:
dgv = _dgv.cumsum(-2)
else:
_dgv_cumsum = _dgv.cumsum(-2)
dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum
else:
dgv = None
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None
# if scale is None, use d_head_qk ** -0.5 by default. Otherwise specify the scale yourself. e.g. scale = 1.0
def fused_recurrent_gla(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
gk: torch.Tensor = None,
gv: torch.Tensor = None,
scale: int = -1,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
causal: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
if scale == -1:
scale = q.shape[-1] ** -0.5
if initial_state is not None:
initial_state = initial_state.detach()
if causal:
o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state)
return o, final_state
else:
# do not support initial_state yet. looks very strange for bidirectional modeling
assert initial_state is None
assert output_final_state is False
o, final_state = FusedRecurrentGLAFunction.apply(
q, k, v, gk, gv, scale, initial_state, output_final_state, False)
o_reversed, final_state = FusedRecurrentGLAFunction.apply(
q, k, v, gk, gv, scale, initial_state, output_final_state, True)
return [o, o_reversed]

View File

@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_hgrn
from .recurrent_fuse import fused_recurrent_hgrn
__all__ = [
'chunk_hgrn',
'fused_recurrent_hgrn'
]

373
finetune/lora/v6/fla/ops/hgrn/chunk.py vendored Normal file
View File

@@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Yu Zhang, Songlin Yang
# this function implements the chunkwise form of HGRN, inspired by
# [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html)
# also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan
# from tests on H800, with B, H, D = 16, 4, 128, we see that the chunk can be greatly faster than the recurrent:
#
# Performance:
# seq_len chunk recurrent chunk_bwd recurrent_bwd
# 0 128.0 0.039360 0.061056 0.312160 0.205008
# 1 256.0 0.045824 0.123712 0.308784 0.297696
# 2 512.0 0.058688 0.241952 0.310720 0.626528
# 3 1024.0 0.088288 0.476992 0.313184 1.333152
# 4 2048.0 0.169472 0.943264 0.452464 2.724864
# 5 4096.0 0.329920 1.886144 0.881600 5.551520
# 6 8192.0 0.647872 3.755040 1.740496 11.117184
# 7 16384.0 1.272064 7.520576 3.446608 22.362528
from typing import Tuple
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
@triton.autotune(
configs=[
triton.Config({'BD': 32}, num_warps=1),
triton.Config({'BD': 32}, num_warps=2),
triton.Config({'BD': 32}, num_warps=4),
triton.Config({'BD': 32}, num_warps=8),
triton.Config({'BD': 64}, num_warps=1),
triton.Config({'BD': 64}, num_warps=2),
triton.Config({'BD': 64}, num_warps=4),
triton.Config({'BD': 64}, num_warps=8),
triton.Config({'BD': 128}, num_warps=1),
triton.Config({'BD': 128}, num_warps=2),
triton.Config({'BD': 128}, num_warps=4),
triton.Config({'BD': 128}, num_warps=8),
],
key=['D']
)
@triton.jit
def chunk_hgrn_fwd_kernel_h(
x,
g,
gc,
o,
h0,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr
):
i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
p_x = x + i_bh * T * D + i_t * BT * D + o_d
p_g = g + i_bh * T * D + i_t * BT * D + o_d
p_gc = gc + i_bh * T * D + i_t * BT * D + o_d
p_o = o + i_bh * T * D + i_t * BT * D + o_d
b_h = tl.zeros([BD], dtype=tl.float32)
b_gc = tl.zeros([BD], dtype=tl.float32)
if USE_INITIAL_STATE:
if i_t == 0:
b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
for i in range(0, BT):
mask_t = mask & ((i_t * BT + i) < T)
b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)
b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)
b_h = tl.exp(b_g) * b_h + b_x
b_gc = b_gc + b_g
tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)
tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)
p_x += D
p_g += D
p_gc += D
p_o += D
@triton.jit
def chunk_hgrn_fwd_kernel_o(
gc,
o,
s_h,
s_t,
s_d,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr
):
i_d, i_bh = tl.program_id(0), tl.program_id(1)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
for i_t in range(1, tl.cdiv(T, BT)):
p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
# [BD,]
b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)
# [BT, BD]
b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
b_o = b_o + tl.exp(b_gc) * b_h0[None, :]
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({'BD': 32}, num_warps=1),
triton.Config({'BD': 32}, num_warps=2),
triton.Config({'BD': 32}, num_warps=4),
triton.Config({'BD': 32}, num_warps=8),
triton.Config({'BD': 64}, num_warps=1),
triton.Config({'BD': 64}, num_warps=2),
triton.Config({'BD': 64}, num_warps=4),
triton.Config({'BD': 64}, num_warps=8),
triton.Config({'BD': 128}, num_warps=1),
triton.Config({'BD': 128}, num_warps=2),
triton.Config({'BD': 128}, num_warps=4),
triton.Config({'BD': 128}, num_warps=8),
],
key=['D']
)
@triton.jit
def chunk_hgrn_bwd_kernel_h(
g,
gc,
dx,
do,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr
):
i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
BC = min(BT, T - i_t * BT)
NT = tl.num_programs(1)
p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d
p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d
p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d
p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d
if i_t == NT - 1:
b_gc = tl.zeros([BD], dtype=tl.float32)
else:
b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)
b_dh = tl.zeros([BD], dtype=tl.float32)
for _ in range(BC - 1, -1, -1):
tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)
b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
b_gc = b_gc + b_g
b_dh = b_dh + b_do
b_dx = b_dh
b_dh = b_dh * tl.exp(b_g)
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
p_g -= D
p_gc -= D
p_dx -= D
p_do -= D
@triton.jit
def chunk_hgrn_bwd_kernel_o(
g,
gc,
o,
dx,
dg,
s_h,
s_t,
s_d,
T: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr
):
i_d, i_bh = tl.program_id(0), tl.program_id(1)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))
p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
# [BD,]
mask_t = mask & ((i_t + 1) * BT < T)
b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)
# [BT, BD]
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)
b_dg = tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)
b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]
b_dg = b_o * b_dx * tl.exp(b_g)
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
class ChunkHGRNFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, x, g, initial_state=None, output_final_state=False):
B, H, T, D = x.shape
BT, BD = 128, min(64, triton.next_power_of_2(D))
num_warps = 8 if BD == 64 else 4
gc = torch.empty_like(g, dtype=torch.float)
o = torch.empty_like(x, dtype=torch.float)
def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
chunk_hgrn_fwd_kernel_h[grid](
x, g, gc, o, initial_state,
T, D,
BT=BT,
USE_INITIAL_STATE=initial_state is not None
)
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
chunk_hgrn_fwd_kernel_o[grid](
gc, o,
o.stride(1), o.stride(2), o.stride(3),
T, D,
BT=BT, BD=BD,
num_warps=num_warps
)
final_state = None
if output_final_state:
final_state = o[:, :, -1].clone()
o = o.to(x.dtype)
ctx.save_for_backward(g, o, initial_state)
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, dht=None):
g, o, initial_state = ctx.saved_tensors
B, H, T, D = do.shape
BT, BD = 128, min(64, triton.next_power_of_2(D))
num_warps = 8 if BD == 64 else 4
gc = torch.empty_like(g, dtype=torch.float)
dx = torch.empty_like(o)
dg = torch.empty_like(g)
def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
chunk_hgrn_bwd_kernel_h[grid](
g, gc, dx, do,
T, D,
BT=BT
)
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
chunk_hgrn_bwd_kernel_o[grid](
g, gc, o, dx, dg,
o.stride(1), o.stride(2), o.stride(3),
T, D,
BT=BT, BD=BD,
num_warps=num_warps
)
if initial_state is not None:
dg[:, :, 0] = initial_state * dx[:, :, 0] * g[:, :, 0].exp()
return dx, dg, None, None
def chunk_hgrn(
x: torch.Tensor,
g: torch.Tensor,
initial_state: torch.Tensor = None,
output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)
return o, final_state
if __name__ == '__main__':
import torch.nn.functional as F
from fla.ops.hgrn.naive import naive_recurrent_hgrn
from fla.ops.hgrn.recurrent_fuse import fused_recurrent_hgrn
B, H, T, D = 8, 4, 512, 128
dtype = torch.bfloat16
torch.manual_seed(42)
# [batch_size, n_heads, seq_len, d_head]
x = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
g = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
x, g = (1 - g.sigmoid()) * x, F.logsigmoid(g)
print(f'x:\t{float(x.min()):>10.6f}\t{float(x.max()):>10.6f}')
print(f'g:\t{float(g.min()):>10.6f}\t{float(g.max()):>10.6f}')
x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g))
print(f"DTYPE:\t{x.dtype}")
do = torch.randn_like(x)
h0 = torch.randn_like(x[:, :, 0])
ref, ref_ht = naive_recurrent_hgrn(x, g, h0, output_final_state=True)
ref.backward(do)
ref_dx, x.grad = x.grad.clone(), None
ref_dg, g.grad = g.grad.clone(), None
tri, tri_ht = fused_recurrent_hgrn(x, g, h0, output_final_state=True)
tri.backward(do)
tri_dx, x.grad = x.grad.clone(), None
tri_dg, g.grad = g.grad.clone(), None
print(" \t DIFF\t MAX")
print(' o\t', f"{float((ref - tri).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}")
print('ht\t', f"{float((ref_ht[0] - tri_ht[0]).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}")
print('dx\t', f"{float((ref_dx - tri_dx).abs().max()):>10.6f}\t{float(ref_dx.max()):>10.6f}")
print('dg\t', f"{float((ref_dg - tri_dg).abs().max()):>10.6f}\t{float(ref_dg.max()):>10.6f}")
print('Done!')
@triton.testing.perf_report(
triton.testing.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['seq_len'],
# different possible values for `x_name`
x_vals=[128 * 2 ** i for i in range(0, 8)],
# argument name whose value corresponds to a different line in the plot
line_arg='provider',
# possible values for `line_arg``
line_vals=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
# label name for the lines
line_names=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
# line styles
styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')],
ylabel="Execution Time (ms)", # label name for the y-axis
# name for the plot. Used also as a file name for saving the plot.
plot_name="Performance",
args={},
)
)
def benchmark(seq_len, provider):
dtype = torch.bfloat16
B, H, D = 16, 4, 128
x = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda')
g = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda').sigmoid()
x = (1 - g) * x
x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g))
do = torch.randn_like(x, dtype=dtype)
quantiles = [0.5, 0.2, 0.8]
results = 0, 0, 0
if provider == 'chunk':
results = triton.testing.do_bench(lambda: chunk_hgrn(x, g), quantiles=quantiles)
if provider == 'recurrent':
results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g), quantiles=quantiles)
if provider == 'chunk_bwd':
results = triton.testing.do_bench(lambda: chunk_hgrn(x, g)[0].backward(do), quantiles=quantiles)
if provider == 'recurrent_bwd':
results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g)[0].backward(do), quantiles=quantiles)
return results
benchmark.run(print_data=True)

31
finetune/lora/v6/fla/ops/hgrn/naive.py vendored Normal file
View File

@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
from typing import Optional
import torch
def naive_recurrent_hgrn(
x: torch.Tensor,
g: torch.Tensor,
initial_state: Optional[torch.Tensor] = None,
output_final_state: Optional[bool] = False
) -> torch.Tensor:
dtype = x.dtype
x, g = map(lambda i: i.float(), (x, g))
B, H, T, D = x.shape
h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
o = torch.zeros_like(x)
final_state = None
if initial_state is not None:
h += initial_state.detach()
for i in range(T):
h = g[:, :, i].exp() * h + x[:, :, i]
o[:, :, i] = h
if output_final_state:
final_state = h
return o.to(dtype), final_state

View File

@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
@triton.autotune(
configs=[
triton.Config({'BD': 32}, num_warps=1),
triton.Config({'BD': 32}, num_warps=2),
triton.Config({'BD': 32}, num_warps=4),
triton.Config({'BD': 32}, num_warps=8),
triton.Config({'BD': 64}, num_warps=1),
triton.Config({'BD': 64}, num_warps=2),
triton.Config({'BD': 64}, num_warps=4),
triton.Config({'BD': 64}, num_warps=8),
triton.Config({'BD': 128}, num_warps=1),
triton.Config({'BD': 128}, num_warps=2),
triton.Config({'BD': 128}, num_warps=4),
triton.Config({'BD': 128}, num_warps=8),
],
key=['D']
)
@triton.jit
def fused_recurrent_hgrn_fwd_kernel(
x,
g,
o,
h0,
ht,
T: tl.constexpr,
D: tl.constexpr,
BD: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_d, i_bh = tl.program_id(0), tl.program_id(1)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
p_x = x + i_bh * T * D + o_d
p_g = g + i_bh * T * D + o_d
p_o = o + i_bh * T * D + o_d
b_h = tl.zeros([BD], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = h0 + i_bh * D + o_d
b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
for _ in range(0, T):
b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
b_h = tl.exp(b_g) * b_h + b_x
tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)
p_x += D
p_g += D
p_o += D
if STORE_FINAL_STATE:
p_ht = ht + i_bh * D + o_d
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
@triton.autotune(
configs=[
triton.Config({'BD': 32}, num_warps=1),
triton.Config({'BD': 32}, num_warps=2),
triton.Config({'BD': 32}, num_warps=4),
triton.Config({'BD': 32}, num_warps=8),
triton.Config({'BD': 64}, num_warps=1),
triton.Config({'BD': 64}, num_warps=2),
triton.Config({'BD': 64}, num_warps=4),
triton.Config({'BD': 64}, num_warps=8),
triton.Config({'BD': 128}, num_warps=1),
triton.Config({'BD': 128}, num_warps=2),
triton.Config({'BD': 128}, num_warps=4),
triton.Config({'BD': 128}, num_warps=8),
],
key=['D']
)
@triton.jit
def fused_recurrent_hgrn_bwd_kernel(
g,
o,
dx,
dg,
do,
h0,
T: tl.constexpr,
D: tl.constexpr,
BD: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr
):
i_d, i_bh = tl.program_id(0), tl.program_id(1)
o_d = i_d * BD + tl.arange(0, BD)
mask = o_d < D
p_g = g + (i_bh * T + T - 1) * D + o_d
p_o = o + (i_bh * T + T - 2) * D + o_d
p_dx = dx + (i_bh * T + T - 1) * D + o_d
p_dg = dg + (i_bh * T + T - 1) * D + o_d
p_do = do + (i_bh * T + T - 1) * D + o_d
b_dh = tl.zeros([BD], dtype=tl.float32)
for i in range(T - 1, -1, -1):
b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
if i > 0:
b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
elif USE_INITIAL_STATE:
b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
else:
b_o = tl.zeros([BD], dtype=tl.float32)
b_dh = b_dh + b_do
b_dx = b_dh
b_dh = b_dh * tl.exp(b_g)
b_dg = b_dh * b_o
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)
p_g -= D
p_o -= D
p_dx -= D
p_dg -= D
p_do -= D
class FusedRecurrentHGRNFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, x, g, initial_state=None, output_final_state=False):
B, H, T, D = x.shape
final_state = None
if output_final_state:
final_state = x.new_empty(B, H, D)
o = torch.empty_like(x)
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
fused_recurrent_hgrn_fwd_kernel[grid](
x, g, o, initial_state, final_state,
T, D,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None
)
ctx.save_for_backward(g, o, initial_state)
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, dht=None):
g, o, initial_state = ctx.saved_tensors
B, H, T, D = do.shape
dx = torch.empty_like(o)
dg = torch.empty_like(g)
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
fused_recurrent_hgrn_bwd_kernel[grid](
g, o, dx, dg, do, initial_state,
T, D,
USE_INITIAL_STATE=initial_state is not None,
)
return dx, dg, None, None
def fused_recurrent_hgrn(
x: torch.Tensor,
g: torch.Tensor,
initial_state: torch.Tensor = None,
output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state)
return o, final_state

View File

@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_linear_attn
from .chunk_fuse import fused_chunk_linear_attn
from .recurrent_fuse import fused_recurrent_linear_attn
__all__ = [
'chunk_linear_attn',
'fused_chunk_linear_attn',
'fused_recurrent_linear_attn'
]

View File

@@ -0,0 +1,359 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
@torch.jit.script
def normalize_output(q, k, o):
k = k.transpose(-2, -1)
k = k.cumsum(-1)
k = k.transpose(-2, -1)
z = (q * k).sum(-1, keepdim=True)
return o / (z + 1e-5)
@triton.jit
def chunk_linear_attn_fwd_kernel_h(
k,
v,
h,
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
for i_t in range(NT):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BK, BV]
b_h += tl.dot(b_k, b_v, allow_tf32=False)
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_linear_attn_fwd_kernel_o(
q,
k,
v,
h,
o,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_s = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
b_o += tl.dot(b_q, b_h, allow_tf32=False)
b_s += tl.dot(b_q, b_k, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_linear_attn_bwd_kernel_dh(
q,
do,
dh,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, V]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BK, BV]
b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
@triton.jit
def chunk_linear_attn_bwd_kernel_dqkv(
q,
k,
v,
h,
do,
dh,
dq,
dk,
dv,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
n_bh = tl.num_programs(2)
o_i = tl.arange(0, BT)
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BT]
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
# [BT, BV]
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# [BT, BT]
b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
# [BT, BK]
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
class ChunkLinearAttentionFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
@contiguous
def forward(ctx, q, k, v, scale, initial_state, output_final_state):
B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
ctx.scale = scale
final_state = None
if output_final_state:
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
h = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_linear_attn_fwd_kernel_h[grid](
k, v, h, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
num_warps=num_warps,
num_stages=num_stages
)
grid = (NV, NT, B * H)
o = torch.empty_like(v)
chunk_linear_attn_fwd_kernel_o[grid](
q, k, v, h, o,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
ctx.save_for_backward(q, k, v, h)
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_ht=None):
q, k, v, h = ctx.saved_tensors
B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(64, triton.next_power_of_2(K)), min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
scale = ctx.scale
dh = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_linear_attn_bwd_kernel_dh[grid](
q, do, dh,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
grid = (NK, NT, B * H)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = v.new_empty(NK, *v.shape)
num_stages = 1
num_warps = 4 if BK == 64 else 2
chunk_linear_attn_bwd_kernel_dqkv[grid](
q, k, v, h, do, dh, dq, dk, dv,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
dv = dv.sum(0)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
def chunk_linear_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float = -1,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
if scale == -1:
scale = q.shape[-1] ** -0.5
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = ChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
if normalize:
o = normalize_output(q * scale, k, o)
return o, final_state

View File

@@ -0,0 +1,326 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from packaging import version
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@torch.jit.script
def normalize_output(q, k, o):
k = k.transpose(-2, -1)
k = k.cumsum(-1)
k = k.transpose(-2, -1)
z = (q * k).sum(-1, keepdim=True)
return o / (z + 1e-5)
@triton.jit
def fused_chunk_linear_attn_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
# [BT, BT]
m_s = o_i[:, None] >= o_i[None, :]
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
# make block pointers
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
for i in range(0, tl.cdiv(T, BT)):
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_k.dtype)
# [BT, BT]
b_s = tl.dot(b_q, b_k, allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
# [BT, BV]
b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
if CHECK and i == 0:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
else:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
p_q = tl.advance(p_q, (BT, 0))
p_k = tl.advance(p_k, (0, BT))
p_v = tl.advance(p_v, (BT, 0))
p_o = tl.advance(p_o, (BT, 0))
if STORE_FINAL_STATE:
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_chunk_linear_attn_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
# [BV, BK]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
for i in range(0, tl.cdiv(T, BT)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [DV, BT]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, DV]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BT, BT]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
b_ds = tl.where(m_s, b_ds, 0)
# [BT, DK]
b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
# [DV, DK]
if CHECK and i == 0:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
else:
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
b_dq *= scale
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
# sync threads
b_h = None
tl.debug_barrier()
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
m_s = o_i[:, None] <= o_i[None, :]
for i in range(1, tl.cdiv(T, BT) + 1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
# [DK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, DV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# b_dd = (b_do]).to(b_do.dtype)
# [BT, BT]
b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
# [BT, BT]
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
# [BT, DK]
b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
# [BT, DV]
b_dv = tl.dot(b_s, b_do, allow_tf32=False)
if CHECK and i == 1:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
else:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
tl.store(p_dk, (b_dk * scale).to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
class FusedChunkLinearAttentionFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, scale, initial_state, output_final_state):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
ctx.scale = scale
BT = 64
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 4
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
else:
final_state = None
# the bug still exists even for Triton 2.2 on H100 GPUs
# so we always enable initial checks
CHECK = True
if version.parse(triton.__version__) < version.parse('2.2.0'):
import warnings
warnings.warn(
"Triton<2.2.0 detected for running this kernel, "
"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
"that lead to significant precision loss. "
"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
)
CHECK = True
grid = (NV, NK, batch_size * n_heads)
fused_chunk_linear_attn_fwd_kernel[grid](
q, k, v, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
CHECK=CHECK,
num_warps=num_warps,
num_stages=num_stages
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, initial_state)
ctx.CHECK = CHECK
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_final_state=None):
q, k, v, initial_state = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = ctx.scale
BT = 64
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 4
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_linear_attn_bwd_kernel[grid](
q, k, v, do, dq, dk, dv, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
CHECK=ctx.CHECK,
num_warps=num_warps,
num_stages=num_stages
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
def fused_chunk_linear_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float = -1,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
if scale == -1:
scale = q.shape[-1] ** -0.5
o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
if normalize:
o = normalize_output(q * scale, k, o)
return o, final_state

View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
import torch
from einops import rearrange
def torch_chunk_linear_attn(q, k, v, chunk_size=64):
q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] **-0.5)
k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size)
v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size)
kv = k.transpose(-1, -2) @ v
kv = kv.cumsum(2)
kv = torch.cat([
torch.zeros_like(kv[:, :, :1]),
kv[:, :, :-1]
], dim=2)
inter = q @ kv
intra = ((q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)) @ v
o = inter + intra
return rearrange(o, 'b h n c d -> b h (n c) d')

View File

@@ -0,0 +1,284 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@torch.jit.script
def normalize_output(q, k, o):
k = k.transpose(-2, -1)
k = k.cumsum(-1)
k = k.transpose(-2, -1)
z = (q * k).sum(-1, keepdim=True)
return o / (z + 1e-5)
@triton.jit
def fused_recurrent_linear_attn_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
initial_state,
final_state, # final hidden state [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
mask_kv = mask_bk[None, :] & mask_bv[:, None]
h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
h += _k[None, :] * _v[:, None]
_o = h * _q[None, :]
_o = tl.sum(_o, axis=1)
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
p_q += DK
p_k += DK
p_o += DV
p_v += DV
if STORE_FINAL_STATE:
p_final_s = final_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_linear_attn_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
# initial hidden state initialization [B, H, D_head_K, D_head_V]
initial_state,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
mask_bk = i_k * BK + tl.arange(0, BK) < DK
mask_bv = i_v * BV + tl.arange(0, BV) < DV
h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
mask_kv = mask_bk[:, None] & mask_bv[None, :]
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[:, None]) * \
DV + (i_v * BV + tl.arange(0, BV)[None, :])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for i in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
h += _k[:, None] * _v[None, :]
_d_q = h * _do[None, :]
d_q = tl.sum(_d_q, axis=1) * scale
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
p_k += DK
p_do += DV
p_v += DV
p_dq += DK
# sync threads
tl.debug_barrier()
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
BK + tl.arange(0, BK) + (T - 1) * DK
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
BV + tl.arange(0, BV) + (T - 1) * DV
d_h = tl.zeros([BK, BV], dtype=tl.float32)
for _ in range(T):
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
d_h += _q[:, None] * _do[None, :]
d_k = tl.sum(d_h * _v[None, :], axis=1)
d_v = tl.sum(d_h * _k[:, None], axis=0)
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
p_do -= DV
p_q -= DK
p_k -= DK
p_v -= DV
p_dk -= DK
p_dv -= DV
class FusedRecurrentLinearAttentionFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, q, k, v, initial_state=None, output_final_state=False):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
else:
final_state = None
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_linear_attn_fwd_kernel[grid](
q, k, v, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, initial_state)
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, d_final_state=None):
q, k, v, initial_state = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_linear_attn_bwd_kernel[grid](
q, k, v, do, dq, dk, dv, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
return dq, dk, dv, None, None
def fused_recurrent_linear_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
normalize: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = FusedRecurrentLinearAttentionFunction.apply(
q, k, v, initial_state, output_final_state)
if normalize:
o = normalize_output(q, k, o)
return o, final_state

View File

@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
from .parallel import parallel_rebased
__all__ = [
'parallel_rebased'
]

View File

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
import torch
from einops import rearrange
from fla.ops.rebased.parallel import parallel_rebased
def naive_parallel_rebased(q, k, v, use_scale=True, use_norm=True):
if use_scale:
q = q * (q.shape[-1] ** -0.5)
attn = q @ k.transpose(-2, -1)
attn = (attn ** 2)
attn.masked_fill_(~torch.tril(torch.ones(
q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
o = attn @ v
if use_norm:
z = attn.sum(-1)
return o / (z[..., None] + 1e-6)
else:
return o
if __name__ == "__main__":
B = 4
H = 4
L = 128
# D = 15
dtype = torch.float32
q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True)
do = torch.randn_like(v).cuda()
ref = naive_parallel_rebased(q, k, v, True, True)
ref.backward(do, retain_graph=True)
ref_dq, q.grad = q.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dv, v.grad = v.grad.clone(), None
# tri = naive_chunk_based(q, k, v)
# tri.backward(do, retain_graph=True)
# tri_dq, q.grad = q.grad.clone(), None
# tri_dk, k.grad = k.grad.clone(), None
# tri_dv, v.grad = v.grad.clone(), None
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
tri = parallel_rebased(q, k, v, 1e-6, True, True)
tri.backward(do, retain_graph=True)
tri_dq, q.grad = q.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dv, v.grad = v.grad.clone(), None
print((ref-tri).abs().max())
print((ref_dq-tri_dq).abs().max())
print((ref_dk-tri_dk).abs().max())
print((ref_dv-tri_dv).abs().max())
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
# tri = parallel_based(q, k, v, True, True)
# tri.backward(do, retain_graph=True)
# tri_dq, q.grad = q.grad.clone(), None
# tri_dk, k.grad = k.grad.clone(), None
# tri_dv, v.grad = v.grad.clone(), None
# print((ref-tri).abs().max())
# print((ref_dq-tri_dq).abs().max())
# print((ref_dk-tri_dk).abs().max())
# print((ref_dv-tri_dv).abs().max())
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()

View File

@@ -0,0 +1,387 @@
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# Rebased: Linear Transformers with Learnable Kernel Functions are Better In-Context Models
# https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py
@triton.jit
def parallel_rebased_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
z, # normalizer [B, H, L]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q
BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
):
# i_c: chunk index. used for sequence parallelism
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
NV = tl.cdiv(DV, BV)
i_k = i_kv // (NV)
i_v = i_kv % (NV)
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
(s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))
# [BQ, BD] block Q, in the shared memory throughout the whole kernel
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
b_o = tl.zeros([BTL, BV], dtype=tl.float32)
b_z = tl.zeros([BTL], dtype=tl.float32)
# Q block and K block have no overlap
# no need for mask, thereby saving flops
for _ in range(0, i_c * BTL, BTS):
# [BK, BTS]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BTS, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
b_s = tl.dot(b_q, (b_k), allow_tf32=False)
b_s = b_s * b_s
b_z += tl.sum(b_s, axis=1)
# [BQ, BD]
b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
p_k = tl.advance(p_k, (0, BTS))
p_v = tl.advance(p_v, (BTS, 0))
# # rescale interchunk output
tl.debug_barrier()
o_q = tl.arange(0, BTL)
# # sync threads, easy for compiler to optimize
# tl.debug_barrier()
o_k = tl.arange(0, BTS)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
(s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
# Q block and K block have overlap. masks required
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
# [BK, BTS]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BTS, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
m_s = o_q[:, None] >= o_k[None, :]
b_s = tl.dot(b_q, b_k, allow_tf32=False)
b_s = b_s * b_s
b_s = tl.where(m_s, b_s, 0)
b_z += tl.sum(b_s, axis=1)
# [BTL, BV]
b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
p_k = tl.advance(p_k, (0, BTS))
p_v = tl.advance(p_v, (BTS, 0))
o_k += BTS
p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_z, b_z.to(p_z.dtype.element_ty),
mask=((i_c * BTL + tl.arange(0, BTL)) < T))
@triton.jit
def _parallel_rebased_bwd_dq(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
b_q = (b_q * scale).to(b_q.dtype)
b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
(s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))
p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)
b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)
for _ in range(0, i_c * BTL, BTS):
# [BTS, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BV, BTS]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
if i_v == 0:
b_ds += b_dz[:, None]
else:
b_ds = b_ds
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
# [BQ, BD]
b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False)
p_k = tl.advance(p_k, (BTS, 0))
p_v = tl.advance(p_v, (0, BTS))
b_dq *= scale
o_q = tl.arange(0, BTL)
o_k = tl.arange(0, BTS)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
(s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
# Q block and K block have overlap. masks required
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
# [BTS, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BV, BTS]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
m_s = o_q[:, None] >= o_k[None, :]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
if i_v == 0:
b_ds += b_dz[:, None]
else:
b_ds = b_ds
b_ds = tl.where(m_s, b_ds, 0) * scale
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
b_s = tl.where(m_s, b_s, 0)
# [BTL, BK]
b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype),
b_k, allow_tf32=False)
p_k = tl.advance(p_k, (BTS, 0))
p_v = tl.advance(p_v, (0, BTS))
o_k += BTS
p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
return
@triton.jit
def _parallel_rebased_bwd_dkv(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
# compute dk dv
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
(i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(
p_v, boundary_check=(0, 1))
b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
[BTL, BV], dtype=tl.float32)
for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS]
b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \
scale # [BTL, BTS]
b_s2 = b_s * b_s
b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
if i_v == 0:
b_ds += b_dz[None, :] * scale
else:
b_ds = b_ds
b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype),
tl.trans(b_q), allow_tf32=False)
tl.debug_barrier()
o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
# [BK, BQ]
m_s = o_k[:, None] <= o_q[None, :]
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
b_s2 = b_s * b_s
b_s = tl.where(m_s, b_s, 0)
b_s2 = tl.where(m_s, b_s2, 0)
b_ds = tl.dot(b_v, b_do, allow_tf32=False)
if i_v == 0:
b_ds += b_dz[None, :]
else:
b_ds = b_ds
b_ds = tl.where(m_s, b_ds, 0) * scale
# [BK, BD]
b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype),
tl.trans(b_q), allow_tf32=False)
o_q += BTS
p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,
(T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,
(T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
return
@triton.jit
def parallel_rebased_bwd_kernel(
q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
NV = tl.cdiv(DV, BV)
i_k = i_kv // (NV)
i_v = i_kv % (NV)
i_h = i_bh % H
_parallel_rebased_bwd_dq(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV
)
tl.debug_barrier()
_parallel_rebased_bwd_dkv(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV
)
class ParallelBasedFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, scale):
BTL, BTS = 128, 32
assert BTL % BTS == 0
# assert q.shape[-1] % 16 == 0
BK = min(128, triton.next_power_of_2(k.shape[-1]))
BV = min(128, triton.next_power_of_2(v.shape[-1]))
BK, BV = max(BK, 16), max(BV, 16)
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
num_stages = 2
num_warps = 4
NK = triton.cdiv(d_head_qk, BK)
NV = triton.cdiv(d_head_v, BV)
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
assert NK == 1, "will encounter some synchronization issue if not."
o = torch.empty(NK, batch_size, n_heads, seq_len,
d_head_v, device=q.device)
z = torch.empty(NK, batch_size, n_heads, seq_len,
device=q.device)
parallel_rebased_fwd_kernel[grid](
q, k, v, o, z,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
num_warps=num_warps,
num_stages=num_stages
)
ctx.save_for_backward(q, k, v)
ctx.scale = scale
return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, dz):
q, k, v = ctx.saved_tensors
scale = ctx.scale
BTL, BTS = 64, 32
assert BTL % BTS == 0
BK = min(128, triton.next_power_of_2(k.shape[-1]))
BV = min(128, triton.next_power_of_2(v.shape[-1]))
BK, BV = max(BK, 16), max(BV, 16)
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
num_stages = 2
num_warps = 4
NK = triton.cdiv(d_head_qk, BK)
NV = triton.cdiv(d_head_v, BV)
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
assert NK == 1, "will encounter some synchronization issue if not"
dq = torch.empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=q.dtype, device=q.device)
dk = torch.empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=q.dtype, device=q.device)
dv = torch.empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=q.dtype, device=q.device)
parallel_rebased_bwd_kernel[grid](
q, k, v, do, dz, dq, dk, dv,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
num_warps=num_warps,
num_stages=num_stages
)
return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
triton_parallel_based = ParallelBasedFunction.apply
def parallel_rebased(q, k, v, eps=1e-5, use_scale=True, use_normalize=True, return_both=False):
assert q.shape[-1] <= 128, "only support feature dim up to 128"
if use_scale:
scale = q.shape[-1] ** -0.5
else:
scale = 1
o, z = triton_parallel_based(q, k, v, scale)
if return_both:
return o, z
if use_normalize:
o = o / (z[..., None] + eps)
else:
o = o
return o.to(q.dtype)

View File

@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_retention
from .chunk_fuse import fused_chunk_retention
from .parallel import parallel_retention
from .recurrent_fuse import fused_recurrent_retention
__all__ = [
'chunk_retention',
'fused_chunk_retention',
'parallel_retention',
'fused_recurrent_retention'
]

View File

@@ -0,0 +1,364 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
@triton.jit
def chunk_retention_fwd_kernel_h(
k,
v,
h,
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h = i_bh % H
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
o_i = tl.arange(0, BT)
d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((BT - o_i - 1) * b_b)
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
for i_t in range(NT):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BK, BV]
if i_t == NT - 1 and (T % BT) != 0:
d_b = tl.math.exp2((T % BT) * b_b)
d_i = tl.math.exp2(((T % BT) - o_i - 1) * b_b)
b_h = d_b * b_h + tl.dot(b_k, (b_v * d_i[:, None]).to(b_k.dtype), allow_tf32=False)
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_retention_fwd_kernel_o(
q,
k,
v,
h,
o,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h = i_bh % H
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
o_i = tl.arange(0, BT)
d_i = tl.math.exp2((o_i + 1) * b_b)
m_s = o_i[:, None] >= o_i[None, :]
d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_s = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
b_o += tl.dot((b_q * d_i[:, None]).to(b_q.dtype), b_h, allow_tf32=False)
b_s += tl.dot(b_q, b_k, allow_tf32=False)
b_s *= d_s
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_retention_bwd_kernel_dh(
q,
do,
dh,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h = i_bh % H
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
o_i = tl.arange(0, BT)
d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b)
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, V]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BK, BV]
b_dh = d_b * b_dh + tl.dot(b_q, (b_do * d_i[:, None]).to(b_q.dtype), allow_tf32=False)
@triton.jit
def chunk_retention_bwd_kernel_dqkv(
q,
k,
v,
h,
do,
dh,
dq,
dk,
dv,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h = i_bh % H
n_bh = tl.num_programs(2)
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
o_i = tl.arange(0, BT)
d_q, d_k = tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)
d_q = (d_q * scale).to(d_q.dtype)
m_s = o_i[:, None] >= o_i[None, :]
d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_s = tl.dot(b_k, b_q, allow_tf32=False) * tl.trans(d_s)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BT]
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False)
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
# [BT, BV]
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * d_k[:, None] + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# [BT, BT]
b_ds = (b_ds * d_s).to(b_q.dtype)
# [BT, BK]
b_dq = b_dq * d_q[:, None] + tl.dot(b_ds, b_k, allow_tf32=False)
b_dk = b_dk * d_k[:, None] + tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
class ChunkRetentionFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
@contiguous
def forward(ctx, q, k, v, initial_state, output_final_state):
B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
scale = K ** -0.5
final_state = None
if output_final_state:
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
h = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_retention_fwd_kernel_h[grid](
k, v, h, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
num_warps=num_warps,
num_stages=num_stages
)
grid = (NV, NT, B * H)
o = torch.empty_like(v)
chunk_retention_fwd_kernel_o[grid](
q, k, v, h, o,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
ctx.save_for_backward(q, k, v, h)
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_ht=None):
q, k, v, h = ctx.saved_tensors
B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
scale = K ** -0.5
dh = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_retention_bwd_kernel_dh[grid](
q, do, dh,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
grid = (NK, NT, B * H)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = v.new_empty(NK, *v.shape)
num_stages = 1
num_warps = 4 if BK == 64 else 2
chunk_retention_bwd_kernel_dqkv[grid](
q, k, v, h, do, dh, dq, dk, dv,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
dv = dv.sum(0)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None
def chunk_retention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
initial_state: torch.Tensor = None,
output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = ChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)
return o, final_state

View File

@@ -0,0 +1,334 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from packaging import version
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@triton.jit
def fused_chunk_retention_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h = i_bh % H
o_i = tl.arange(0, BT)
# decay rate given the head index
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
# d_b: overall decay for the entire chunk
# d_o: cumulative decay from the start of the chunk
# d_h: cumulative decay from the end of the chunk
d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)
# [BT, BT]
m_s = o_i[:, None] >= o_i[None, :]
d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
# make block pointers
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
NT = tl.cdiv(T, BT)
for i in range(0, NT):
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_k.dtype)
# [BT, BT]
b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s
# [BT, BV]
b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
if CHECK and i == 0:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]
b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)
else:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]
if i == NT - 1 and (T % BT) != 0:
d_b = tl.math.exp2((T % BT) * b_b)
d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b)
b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
p_q = tl.advance(p_q, (BT, 0))
p_k = tl.advance(p_k, (0, BT))
p_v = tl.advance(p_v, (BT, 0))
p_o = tl.advance(p_o, (BT, 0))
if STORE_FINAL_STATE:
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_chunk_retention_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr,
CHECK: tl.constexpr
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h = i_bh % H
o_i = tl.arange(0, BT)
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b)
d_b = tl.math.exp2(BT * b_b)
m_s = o_i[:, None] >= o_i[None, :]
d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale
# [BV, BK]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
for i in range(0, tl.cdiv(T, BT)):
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [DV, BT]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, DV]
b_do = tl.load(p_do, boundary_check=(0, 1))
b_dd = (b_do * d_q[:, None]).to(b_do.dtype)
# [BT, BT]
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
b_ds = (b_ds * d_s).to(b_k.dtype)
# [BT, DK]
b_dq = tl.dot(b_ds, b_k, allow_tf32=False)
# [DV, DK]
if CHECK and i == 0:
b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)
b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)
else:
b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)
b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
# sync threads
b_h = None
tl.debug_barrier()
d_s = tl.trans(d_s)
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i in range(1, tl.cdiv(T, BT) + 1):
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
# [DK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BT, DK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, DV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_dd = (b_do * d_q[:, None]).to(b_do.dtype)
# [BT, BT]
b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
b_ds = (b_ds * d_s).to(b_k.dtype)
# [BT, BT]
b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s
# [BT, DK]
b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
# [BT, DV]
b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
if CHECK and i == 1:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)
else:
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
class FusedChunkRetentionFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v, initial_state, output_final_state):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BT = 64
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 4
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
else:
final_state = None
# the bug still exists even for Triton 2.2 on H100 GPUs
# so we always enable initial checks
CHECK = True
if version.parse(triton.__version__) < version.parse('2.2.0'):
import warnings
warnings.warn(
"Triton<2.2.0 detected for running this kernel, "
"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
"that lead to significant precision loss. "
"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
)
CHECK = True
grid = (NV, NK, batch_size * n_heads)
fused_chunk_retention_fwd_kernel[grid](
q, k, v, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
CHECK=CHECK,
num_warps=num_warps,
num_stages=num_stages
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, initial_state)
ctx.CHECK = CHECK
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_final_state=None):
q, k, v, initial_state = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BT = 64
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 4
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_chunk_retention_bwd_kernel[grid](
q, k, v, do, dq, dk, dv, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
CHECK=ctx.CHECK,
num_warps=num_warps,
num_stages=num_stages
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None
def fused_chunk_retention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
initial_state: torch.Tensor = None,
output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)
return o, final_state

View File

@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
import torch
def naive_retention(q, k, v):
orig_type = q.dtype
q, k, v = q.float(), k.float(), v.float()
_, n_heads, seq_len, d_head = q.shape
s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log2()
n = q.new_tensor(range(seq_len), dtype=torch.float)
n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n)
s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype))
o = torch.einsum('bhqk,bhkd->bhqd', s, v)
return o.to(orig_type)

View File

@@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
@triton.jit
def parallel_retention_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q
BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
):
# i_c: chunk index. used for sequence parallelism
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
NV = tl.cdiv(DV, BV)
i_k = i_kv // (NV)
i_v = i_kv % (NV)
i_h = i_bh % H
# decay rate given the head index
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
# cumulative decay from the end of the chunk
o_k = tl.arange(0, BTS)
d_h = tl.math.exp2((BTS - o_k) * b_b)
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
(s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))
# [BQ, BD] block Q, in the shared memory throughout the whole kernel
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
b_o = tl.zeros([BTL, BV], dtype=tl.float32)
# Q block and K block have no overlap
# no need for mask, thereby saving flops
for _ in range(0, i_c * BTL, BTS):
# [BK, BTS]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BTS, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
b_s = tl.dot(b_q, (b_k), allow_tf32=False) * d_h[None, :]
# [BQ, BD]
b_o = b_o * tl.math.exp2(b_b * BTS)
b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
p_k = tl.advance(p_k, (0, BTS))
p_v = tl.advance(p_v, (BTS, 0))
# # rescale interchunk output
tl.debug_barrier()
o_q = tl.arange(0, BTL)
d_q = tl.math.exp2(tl.arange(0, BTL) * b_b)
b_o *= d_q[:, None]
# # sync threads, easy for compiler to optimize
# tl.debug_barrier()
o_k = tl.arange(0, BTS)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
(s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
# Q block and K block have overlap. masks required
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
# [BK, BTS]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BTS, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
m_s = o_q[:, None] >= o_k[None, :]
d_s = tl.where(m_s, tl.math.exp2(
(o_q[:, None] - o_k[None, :]) * b_b), 0)
b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s
# [BTL, BV]
b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
p_k = tl.advance(p_k, (0, BTS))
p_v = tl.advance(p_v, (BTS, 0))
o_k += BTS
p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),
(s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def _parallel_retention_bwd_dq(
i_bh, i_c, i_k, i_v, i_h,
k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
(s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))
# decay rate given the head index
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
# overall decay rate for an entire block
d_b = tl.math.exp2(b_b * BTS)
# cumulative decay from the end of the chunk
d_h = tl.math.exp2((BTS - tl.arange(0, BTS)) * b_b)
for _ in range(0, i_c * BTL, BTS):
# [BTS, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BV, BTS]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_h[None, :]
# [BQ, BD]
b_dq *= d_b
b_dq += tl.dot(b_ds.to(b_v.dtype), b_k, allow_tf32=False)
p_k = tl.advance(p_k, (BTS, 0))
p_v = tl.advance(p_v, (0, BTS))
b_dq *= tl.math.exp2(tl.arange(0, BTL) * b_b)[:, None] * scale
o_q = tl.arange(0, BTL)
o_k = tl.arange(0, BTS)
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
(s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
# Q block and K block have overlap. masks required
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
# [BTS, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BV, BTS]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BTL, BTS]
m_s = o_q[:, None] >= o_k[None, :]
d_s = tl.where(m_s, tl.math.exp2(
(o_q[:, None] - o_k[None, :]) * b_b), 0)
b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_s * scale
# [BTL, BK]
b_dq += tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
p_k = tl.advance(p_k, (BTS, 0))
p_v = tl.advance(p_v, (0, BTS))
o_k += BTS
p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
return
@triton.jit
def _parallel_retention_bwd_dkv(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
# no overlap. no need for mask.
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
# overall decay rate for an entire block
d_b = tl.math.exp2(b_b * BTS)
# compute dk dv
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
(i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(
p_v, boundary_check=(0, 1))
b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
[BTL, BV], dtype=tl.float32)
d_h = tl.math.exp2((BTL - tl.arange(0, BTL)) * b_b)
b_kd = (b_k * d_h[:, None]).to(b_k.dtype)
d_q = tl.math.exp2(tl.arange(0, BTS) * b_b)
for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]
b_do = tl.load(p_do, boundary_check=(0, 1)) # [BV, BTS]
b_do = (b_do * d_q[None, :]).to(b_do.dtype)
b_dv *= d_b
b_s = tl.dot(b_kd.to(b_q.dtype), b_q, allow_tf32=False) # [BTL, BTS]
b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
b_dk *= d_b
b_ds = tl.dot(b_v, b_do, allow_tf32=False)
b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
b_dk *= d_h[:, None] * scale
b_dv *= scale
tl.debug_barrier()
o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BK, BQ]
m_s = o_k[:, None] <= o_q[None, :]
d_s = tl.where(m_s, tl.math.exp2(
(-o_k[:, None] + o_q[None, :]) * b_b.to(tl.float32)), 0) * scale
b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s
b_ds = tl.dot(b_v, b_do, allow_tf32=False) * d_s
# [BK, BD]
b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
o_q += BTS
p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,
(T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,
(T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
return
@triton.jit
def parallel_retention_bwd_kernel(
q, k, v, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale,
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
DK: tl.constexpr, DV: tl.constexpr,
):
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
NV = tl.cdiv(DV, BV)
i_k = i_kv // (NV)
i_v = i_kv % (NV)
i_h = i_bh % H
_parallel_retention_bwd_dq(
i_bh, i_c, i_k, i_v, i_h,
k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV
)
tl.debug_barrier()
_parallel_retention_bwd_dkv(
i_bh, i_c, i_k, i_v, i_h,
q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV
)
class ParallelRetentionFunction(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, q, k, v):
BTL, BTS = 128, 32
assert BTL % BTS == 0
BK = min(128, triton.next_power_of_2(k.shape[-1]))
BV = min(128, triton.next_power_of_2(v.shape[-1]))
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
num_stages = 3 if d_head_qk <= 64 else 2
num_warps = 4
NK = triton.cdiv(d_head_qk, BK)
NV = triton.cdiv(d_head_v, BV)
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
scale = d_head_qk ** -0.5
o = torch.empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=q.dtype, device=q.device)
parallel_retention_fwd_kernel[grid](
q, k, v, o,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
num_warps=num_warps,
num_stages=num_stages
)
ctx.save_for_backward(q, k, v)
return o.sum(0).to(q.dtype)
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do):
q, k, v = ctx.saved_tensors
BTL, BTS = 64, 32
assert BTL % BTS == 0
BK = min(128, triton.next_power_of_2(k.shape[-1]))
BV = min(128, triton.next_power_of_2(v.shape[-1]))
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
num_stages = 3 if d_head_qk <= 64 else 2
num_warps = 4
NK = triton.cdiv(d_head_qk, BK)
NV = triton.cdiv(d_head_v, BV)
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
scale = d_head_qk ** -0.5
dq = torch.empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=q.dtype, device=q.device)
dk = torch.empty(NV, batch_size, n_heads, seq_len,
d_head_qk, dtype=q.dtype, device=q.device)
dv = torch.empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=q.dtype, device=q.device)
parallel_retention_bwd_kernel[grid](
q, k, v, do, dq, dk, dv,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
num_warps=num_warps,
num_stages=num_stages
)
return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype)
parallel_retention = ParallelRetentionFunction.apply

View File

@@ -0,0 +1,281 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
# on-the-fly computation without materializing hidden statets into HBMs
@triton.jit
def fused_recurrent_retention_fwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
o, # output [B, H, L, D_head_V]
initial_state,
final_state, # final hidden state [B, H, D_head_K, D_head_V]
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
):
# indices
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h = i_bh % H
# decay rate given the head index
b_b = (1 - tl.math.pow(2, -5 - i_h * 1.0))
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
mask_kv = mask_bk[None, :] & mask_bv[:, None]
h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
h = b_b * h + _k[None, :] * _v[:, None]
_o = h * _q[None, :]
_o = tl.sum(_o, axis=1)
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
p_q += DK
p_k += DK
p_o += DV
p_v += DV
if STORE_FINAL_STATE:
p_final_s = final_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[None, :]) * \
DV + (i_v * BV + tl.arange(0, BV)[:, None])
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_retention_bwd_kernel(
# B: batch_size, H: n_heads, T: seq_len, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, L, D_head_K]
k, # key [B, H, L, D_head_V]
v, # value [B, H, L, D_head_V]
do, # gradient of output [B, H, L, D_head_V]
dq, # gradient of query [NV, B, H, L, D_head_K]
dk, # gradient of key [NV, B, H, L, D_head_K]
dv, # gradient of value [NK, B, H, L, D_head_V]
# initial hidden state initialization [B, H, D_head_K, D_head_V]
initial_state,
s_qk_h, # stride size: L * D_head_K
s_qk_t, # stride size: D_head_K
s_qk_d, # stride size: 1
s_vo_h, # stride size: L * D_head_V
s_vo_t, # stride size: D_head_V
s_vo_d, # stride size: 1
B, # batch_size
H, # n_heads
T, # seq_len
scale, # D_head_K ** -0.5
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
DK: tl.constexpr, # D_head_K
DV: tl.constexpr, # D_head_V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h = i_bh % H
b_b = 1 - tl.math.pow(2, -5 - i_h * 1.0)
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
mask_bk = i_k * BK + tl.arange(0, BK) < DK
mask_bv = i_v * BV + tl.arange(0, BV) < DV
h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
mask_kv = mask_bk[:, None] & mask_bv[None, :]
p_init_s = initial_state + i_bh * DK * DV + \
(i_k * BK + tl.arange(0, BK)[:, None]) * \
DV + (i_v * BV + tl.arange(0, BV)[None, :])
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
for i in range(0, T):
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
h = b_b * h + _k[:, None] * _v[None, :]
_d_q = h * _do[None, :]
d_q = tl.sum(_d_q, axis=1) * scale
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
p_k += DK
p_do += DV
p_v += DV
p_dq += DK
# sync threads
tl.debug_barrier()
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
BK + tl.arange(0, BK) + (T - 1) * DK
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
BV + tl.arange(0, BV) + (T - 1) * DV
d_h = tl.zeros([BK, BV], dtype=tl.float32)
for _ in range(T):
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
d_h += _q[:, None] * _do[None, :]
d_k = tl.sum(d_h * _v[None, :], axis=1)
d_v = tl.sum(d_h * _k[:, None], axis=0)
d_h *= b_b
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
p_do -= DV
p_q -= DK
p_k -= DK
p_v -= DV
p_dk -= DK
p_dv -= DV
class FusedRecurrentRetentionFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, q, k, v, initial_state=None, output_final_state=False):
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
else:
final_state = None
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_retention_fwd_kernel[grid](
q, k, v, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, initial_state)
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, d_final_state=None):
q, k, v, initial_state = ctx.saved_tensors
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
scale = d_head_qk ** -0.5
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_retention_bwd_kernel[grid](
q, k, v, do, dq, dk, dv, initial_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None
)
dq = dq.sum(0)
dk = dk.sum(0)
dv = dv.sum(0)
return dq, dk, dv, None, None
# fused_recurrent_retention = FusedRecurrentRetentionFunction.apply
def fused_recurrent_retention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
initial_state: torch.Tensor = None,
output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = FusedRecurrentRetentionFunction.apply(q, k, v, initial_state, output_final_state)
return o, final_state

252
finetune/lora/v6/fla/ops/rotary.py vendored Normal file
View File

@@ -0,0 +1,252 @@
# Copyright (c) 2023, Tri Dao. https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py
from typing import Optional, Union
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_M": 2}),
# triton.Config({"BLOCK_M": 4}),
# triton.Config({"BLOCK_M": 8}),
# triton.Config({"BLOCK_M": 16}),
# ],
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
# )
@triton.jit
def rotary_kernel(
OUT, # Pointers to matrices
X,
COS,
SIN,
CU_SEQLENS,
SEQLEN_OFFSETS, # this could be int or a pointer
# Matrix dimensions
seqlen,
nheads,
rotary_dim,
seqlen_ro,
CACHE_KEY_SEQLEN,
# strides
stride_out_batch,
stride_out_seqlen,
stride_out_nheads,
stride_out_headdim,
stride_x_batch,
stride_x_seqlen,
stride_x_nheads,
stride_x_headdim,
# Meta-parameters
BLOCK_K: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
IS_VARLEN: tl.constexpr,
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
BLOCK_M: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_batch = tl.program_id(axis=1)
pid_head = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2
if not IS_VARLEN:
X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
else:
start_idx = tl.load(CU_SEQLENS + pid_batch)
seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
if pid_m * BLOCK_M >= seqlen:
return
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
if not IS_SEQLEN_OFFSETS_TENSOR:
rm_cs = rm + SEQLEN_OFFSETS
else:
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
rk = tl.arange(0, BLOCK_K)
rk_half = tl.arange(0, BLOCK_K // 2)
if not INTERLEAVED:
# Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
X = X + (rm[:, None] * stride_x_seqlen +
rk_half[None, :] * stride_x_headdim)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
cos = tl.load(
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
).to(tl.float32)
sin = tl.load(
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x0 = tl.load(
X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
).to(tl.float32)
x1 = tl.load(
X + rotary_dim_half * stride_x_headdim,
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
if CONJUGATE:
sin = -sin
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
# write back result
OUT = OUT + (rm[:, None] * stride_out_seqlen +
rk_half[None, :] * stride_out_headdim)
tl.store(OUT, o0, mask=(rm[:, None] < seqlen)
& (rk_half[None, :] < rotary_dim_half))
tl.store(
OUT + rotary_dim_half * stride_out_headdim,
o1,
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
)
else:
# We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
# Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
# Loading x0 will be fast but x1 will be slow.
# Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
# Then we do the calculation and use tl.where to pick put the right outputs for the even
# and for the odd indices.
rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
rk_repeat = tl.arange(0, BLOCK_K) // 2
X0 = X + (rm[:, None] * stride_x_seqlen +
rk[None, :] * stride_x_headdim)
X1 = X + (rm[:, None] * stride_x_seqlen +
rk_swap[None, :] * stride_x_headdim)
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
cos = tl.load(
COS,
mask=(rm_cs[:, None] < seqlen_ro) & (
rk_repeat[None, :] < rotary_dim_half),
other=1.0,
).to(tl.float32)
sin = tl.load(
SIN,
mask=(rm_cs[:, None] < seqlen_ro) & (
rk_repeat[None, :] < rotary_dim_half),
other=0.0,
).to(tl.float32)
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
tl.float32
)
x1 = tl.load(
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
).to(tl.float32)
if CONJUGATE:
sin = -sin
x0_cos = x0 * cos
x1_sin = x1 * sin
out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
OUT = OUT + (rm[:, None] * stride_out_seqlen +
rk[None, :] * stride_out_headdim)
tl.store(OUT, out, mask=(rm[:, None] < seqlen)
& (rk[None, :] < rotary_dim))
def apply_rotary(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
seqlen_offsets: Union[int, torch.Tensor] = 0,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
interleaved=False,
inplace=False,
conjugate=False,
) -> torch.Tensor:
"""
Arguments:
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
else (total_seqlen, nheads, headdim).
cos: (seqlen_ro, rotary_dim / 2)
sin: (seqlen_ro, rotary_dim / 2)
seqlen_offsets: integer or integer tensor of size (batch,)
cu_seqlens: (batch + 1,) or None
max_seqlen: int
Returns:
y: (batch, seqlen, nheads, headdim)
"""
is_varlen = cu_seqlens is not None
if not is_varlen:
batch, seqlen, nheads, headdim = x.shape
else:
assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
total_seqlen, nheads, headdim = x.shape
batch_p_1 = cu_seqlens.shape[0]
batch = batch_p_1 - 1
seqlen = max_seqlen
seqlen_ro, rotary_dim = cos.shape
assert sin.shape == cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
assert headdim <= 256, "Only support headdim <= 256"
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
assert (
cos.dtype == sin.dtype
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
assert (
x.dtype == cos.dtype
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
cos, sin = cos.contiguous(), sin.contiguous()
if isinstance(seqlen_offsets, torch.Tensor):
assert seqlen_offsets.shape == (batch,)
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
seqlen_offsets = seqlen_offsets.contiguous()
else:
assert seqlen_offsets + seqlen <= seqlen_ro
output = torch.empty_like(x) if not inplace else x
if rotary_dim < headdim and not inplace:
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
BLOCK_K = (
32
if rotary_dim <= 32
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
)
def grid(META): return (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(x.device.index):
rotary_kernel[grid](
output, # data ptrs
x,
cos,
sin,
cu_seqlens,
seqlen_offsets,
seqlen, # shapes
nheads,
rotary_dim,
seqlen_ro,
# key for triton cache (limit number of compilations)
seqlen // 128,
# batch_strides if not varlen else 0
output.stride(0) if not is_varlen else 0,
output.stride(-3), # seqlen_stride or total_seqlen_stride
output.stride(-2), # nheads_stride
output.stride(-1), # headdim_stride
# batch_strides if not varlen else 0
x.stride(0) if not is_varlen else 0,
x.stride(-3), # seqlen stride or total_seqlen_stride
x.stride(-2), # nheads stride
x.stride(-1), # headdim stride
BLOCK_K,
isinstance(seqlen_offsets, torch.Tensor),
is_varlen,
interleaved,
conjugate,
BLOCK_M,
)
return output

View File

@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
from .recurrent_fuse import fused_recurrent_rwkv4
__all__ = [
'fused_recurrent_rwkv4'
]

View File

@@ -0,0 +1,484 @@
# -*- coding: utf-8 -*-
# adopted from https://github.com/codekansas/rwkv
from typing import Any, cast
import torch
import triton
import triton.language as tl
from torch import Tensor
from torch.autograd.function import Function, FunctionCtx, once_differentiable
def get_block_size_c(chans: int) -> int:
if chans < 32:
return 32
if chans < 64:
return 64
return 128
@triton.jit
def fused_recurrent_rwkv4_forward_kernel(
# W
w_ptr,
w_s_c,
# U
u_ptr,
u_s_c,
# K
k_ptr,
k_s_b,
k_s_t,
k_s_c,
# V
v_ptr,
v_s_b,
v_s_t,
v_s_c,
# State
state_ptr,
state_s_b,
state_s_abe,
state_s_c,
# WKV
wkv_ptr,
wkv_s_b,
wkv_s_t,
wkv_s_c,
# Output state
state_out_ptr,
state_out_s_b,
state_out_s_abe,
state_out_s_t,
state_out_s_c,
# Params
chans,
tsz,
BLOCK_SIZE_C: tl.constexpr,
):
# Parallelize over the batch dimension.
b_idx = tl.program_id(0)
c_idx = tl.program_id(1)
cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
cmask = cs < chans
# Pointers to the batch (and possibly channel) for the input tensors.
k_ptr = k_ptr + b_idx * k_s_b
v_ptr = v_ptr + b_idx * v_s_b
alpha_ptr = state_ptr + b_idx * state_s_b
beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe
# Pointers to the batch (and possibly channel) for the output tensors.
wkv_ptr = wkv_ptr + b_idx * wkv_s_b
alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b
beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe
eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe
# Loads parameters.
alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)
u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)
for t in range(tsz):
kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)
vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)
ukt = u + kt
tau = tl.maximum(ukt, eps)
e1a = tl.exp(eps - tau)
e2a = tl.exp(ukt - tau)
wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)
tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)
w_eps = w + eps
eps = tl.maximum(w_eps, kt)
e1b = tl.exp(w_eps - eps)
e2b = tl.exp(kt - eps)
alpha = e1b * alpha + e2b * vt
beta = e1b * beta + e2b
tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)
tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)
tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)
def fused_recurrent_rwkv4_forward(
w: Tensor,
u: Tensor,
k: Tensor,
v: Tensor,
state: Tensor,
) -> tuple[Tensor, Tensor]:
(bsz, tsz, chans) = k.shape
# New tensors to output.
wkvs = k.new_empty(bsz, tsz, chans)
state_out = k.new_empty(bsz, 3, tsz, chans)
# Constants.
block_size_c = get_block_size_c(chans)
def grid(meta: dict[str, Any]) -> tuple[int, ...]:
return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))
fused_recurrent_rwkv4_forward_kernel[grid](
# W
w,
w.stride(0),
# U
u,
u.stride(0),
# K
k,
k.stride(0),
k.stride(1),
k.stride(2),
# V
v,
v.stride(0),
v.stride(1),
v.stride(2),
# State
state,
state.stride(0),
state.stride(1),
state.stride(3),
# WKV
wkvs,
wkvs.stride(0),
wkvs.stride(1),
wkvs.stride(2),
# Output state
state_out,
state_out.stride(0),
state_out.stride(1),
state_out.stride(2),
state_out.stride(3),
# Params
chans,
tsz,
BLOCK_SIZE_C=block_size_c,
)
state_out = torch.cat((state, state_out), dim=2)
return wkvs, state_out
@triton.jit
def fused_recurrent_rwkv4_backward_kernel(
# W
w_ptr,
w_s_c,
# U
u_ptr,
u_s_c,
# K
k_ptr,
k_s_b,
k_s_t,
k_s_c,
# V
v_ptr,
v_s_b,
v_s_t,
v_s_c,
# State
state_ptr,
state_s_b,
state_s_abe,
state_s_t,
state_s_c,
# WKV grad
gwkv_ptr,
gwkv_s_b,
gwkv_s_t,
gwkv_s_c,
# Output state grad
gstate_out_ptr,
gstate_out_s_b,
gstate_out_s_abe,
gstate_out_s_c,
# W grad
gw_ptr,
gw_s_c,
# U grad
gu_ptr,
gu_s_c,
# K grad
gk_ptr,
gk_s_b,
gk_s_t,
gk_s_c,
# V grad
gv_ptr,
gv_s_b,
gv_s_t,
gv_s_c,
# State grad
gstate_ptr,
gstate_s_b,
gstate_s_abe,
gstate_s_c,
# Params
tsz,
chans,
BLOCK_SIZE_C: tl.constexpr,
):
# Parallelize over the batch dimension.
b_idx = tl.program_id(0)
c_idx = tl.program_id(1)
cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
cmask = cs < chans
# Pointers to the batch (and possibly channel) for the input tensors.
k_ptr = k_ptr + b_idx * k_s_b
v_ptr = v_ptr + b_idx * v_s_b
alpha_ptr = state_ptr + b_idx * state_s_b
beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe
# Pointers to the batch (and possibly channel) for the output tensors.
gk_ptr = gk_ptr + b_idx * gk_s_b
gv_ptr = gv_ptr + b_idx * gv_s_b
# Pointers to gradients which were recieved by the function.
gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b
galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b
gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe
geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe
# Loads parameters.
galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)
u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)
# Gradient accumulators.
gw = tl.zeros_like(w)
gu = tl.zeros_like(u)
alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
for t in range(tsz):
tc = tsz - t - 1
kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)
vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)
alpha_curr = alpha_prev
beta_curr = beta_prev
eps_curr = eps_prev
alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
ukt = u + kt
tau = tl.maximum(ukt, eps_prev)
e1 = tl.exp(eps_prev - tau)
e2 = tl.exp(ukt - tau)
euke = tl.exp(ukt + eps_prev - 2 * tau)
denom = e1 * beta_prev + e2
denom_sq = denom * denom
gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)
# Backpropagates wkv gradients.
guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq
gu += guk
gk = guk
gv = gwkvt * e2 / denom
galpha_wkv = gwkvt * e1 / denom
gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq
geps_wkv_denom = e1 * beta_prev + e2
geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)
e1 = tl.exp(w + eps_prev - eps_curr)
e2 = tl.exp(kt - eps_curr)
# Backpropagates alpha gradients.
galpha_we = galpha * e1 * alpha_prev
gw += galpha_we
gk += galpha * e2 * vt
gv += galpha * e2
geps += galpha * -alpha_curr
# Backpropagates beta gradients.
gbeta_we = gbeta * e1 * beta_prev
gw += gbeta_we
gk += gbeta * e2
geps += gbeta * -beta_curr
# Backpropagates epsilon gradients.
geps_mask = w + eps_prev > kt
geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))
gw += geps_we
gk += tl.where(geps_mask, tl.zeros_like(geps), geps)
# Stores the gradients for k and v.
tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)
tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)
# Computes new gradients for alpha and beta.
galpha = galpha * e1 + galpha_wkv
gbeta = gbeta * e1 + gbeta_wkv
geps = galpha_we + gbeta_we + geps_we + geps_wkv
# Stores final gradients for alpha and beta.
galpha_ptr = gstate_ptr + b_idx * gstate_s_b
gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe
geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe
tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)
tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)
tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)
# Stores final gradients for w and u.
gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32)
gw_temp += gw
tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask)
gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32)
gu_temp += gu
tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask)
def fused_recurrent_rwkv4_backward(
w: Tensor,
u: Tensor,
k: Tensor,
v: Tensor,
state: Tensor,
grad_wkv: Tensor,
grad_state: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
bsz, tsz, chans = k.shape
gw = torch.zeros_like(w) # New tensors to output.
gu = torch.zeros_like(u)
gk = torch.empty_like(k)
gv = torch.empty_like(v)
gstate = k.new_empty(bsz, 3, 1, chans)
block_size_c = get_block_size_c(chans) # Constants.
def grid(meta: dict[str, Any]) -> tuple[int, ...]:
return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))
fused_recurrent_rwkv4_backward_kernel[grid](
# W
w,
w.stride(0),
# U
u,
u.stride(0),
# K
k,
k.stride(0),
k.stride(1),
k.stride(2),
# V
v,
v.stride(0),
v.stride(1),
v.stride(2),
# State
state,
state.stride(0),
state.stride(1),
state.stride(2),
state.stride(3),
# WKV grad
grad_wkv,
grad_wkv.stride(0),
grad_wkv.stride(1),
grad_wkv.stride(2),
# Output state grad
grad_state,
grad_state.stride(0),
grad_state.stride(1),
grad_state.stride(3),
# W grad
gw,
gw.stride(0),
# U grad
gu,
gu.stride(0),
# K grad
gk,
gk.stride(0),
gk.stride(1),
gk.stride(2),
# V grad
gv,
gv.stride(0),
gv.stride(1),
gv.stride(2),
# State grad
gstate,
gstate.stride(0),
gstate.stride(1),
gstate.stride(3),
# Params
tsz,
chans,
BLOCK_SIZE_C=block_size_c,
)
return gw, gu, gk, gv, gstate
class FusedRecurrentRWKV4Function(Function):
@staticmethod
def forward(
ctx: FunctionCtx,
w: Tensor,
u: Tensor,
k: Tensor,
v: Tensor,
state: Tensor,
) -> tuple[Tensor, Tensor]:
ctx.input_dtype = k.dtype
if (
w.device.type != "cuda"
or u.device.type != "cuda"
or k.device.type != "cuda"
or v.device.type != "cuda"
):
raise ValueError(
"Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices."
)
w = -torch.exp(w.float().contiguous())
if k.dtype == torch.float16:
u = u.float()
k = k.float()
v = v.float()
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
wkv, state_out = fused_recurrent_rwkv4_forward(w, u, k, v, state)
ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1])
return wkv, state_out[:, :, -1:]
@staticmethod
@once_differentiable
def backward(ctx: FunctionCtx, gwkv: Tensor, gstate: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors)
gw, gu, gk, gv, gstate = fused_recurrent_rwkv4_backward(w, u, k, v, state, gwkv, gstate)
return gw, gu, gk, gv, gstate
def fused_recurrent_rwkv4(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
return FusedRecurrentRWKV4Function.apply(w, u, k, v, state)

View File

@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_rwkv6
from .recurrent_fuse import fused_recurrent_rwkv6
__all__ = [
'chunk_rwkv6',
'fused_recurrent_rwkv6'
]

921
finetune/lora/v6/fla/ops/rwkv6/chunk.py vendored Normal file
View File

@@ -0,0 +1,921 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from fla.ops.utils import chunk_reversed_cumsum_fwd
from fla.utils import contiguous
@triton.autotune(
configs=[
triton.Config({'BS': 16}, num_warps=2),
triton.Config({'BS': 16}, num_warps=4),
triton.Config({'BS': 16}, num_warps=8),
triton.Config({'BS': 32}, num_warps=2),
triton.Config({'BS': 32}, num_warps=4),
triton.Config({'BS': 32}, num_warps=8),
triton.Config({'BS': 64}, num_warps=2),
triton.Config({'BS': 64}, num_warps=4),
triton.Config({'BS': 64}, num_warps=8),
],
key=['S']
)
@triton.jit
def chunk_rwkv6_fwd_kernel_cum(
s,
o,
o_minus_s,
s_s_h,
s_s_t,
s_s_d,
T: tl.constexpr,
S: tl.constexpr,
BT: tl.constexpr,
BS: tl.constexpr
):
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
# [BT, BS]
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
b_o = tl.dot(m_s, b_s, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def post_process_grad(
q,
k,
v,
u,
do,
dk,
dq,
du,
scale,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
H,
T: tl.constexpr,
BT: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_h = i_bh % H
# Note that BK = tl.next_power_of_2(K), BV = tl.next_power_of_2(V)
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
p_du = tl.make_block_ptr(du + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))
p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_u = tl.load(p_u, boundary_check=(0,))
b_vdo = tl.sum(b_v * b_do, axis=1)
b_du = b_vdo[:, None] * b_k * b_q * scale
b_dq = b_vdo[:, None] * b_k * b_u[None, :] * scale
b_dk = b_vdo[:, None] * b_q * b_u[None, :] * scale
b_dq += tl.load(p_dq, boundary_check=(0, 1))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
b_dk += tl.load(p_dk, boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_du, b_du.to(p_du.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_rwkv6_fwd_kernel_h(
k,
v,
g,
h,
h0,
ht,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
for i_t in range(NT):
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BK, BT]
b_g = tl.load(p_g, boundary_check=(0, 1))
if i_t < NT - 1:
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
else:
b_gn = tl.min(b_g, axis=1)
b_h *= tl.exp(b_gn)[:, None]
b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
b_h += tl.dot(b_k, b_v, allow_tf32=False)
if STORE_FINAL_STATE:
p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_rwkv6_fwd_kernel_intra(
q,
k,
g,
gs,
u,
A,
s_k_h,
s_k_t,
s_k_d,
scale,
H,
T: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr,
DK: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
i_h = i_bh % H
n_bh = tl.num_programs(2)
o_k = i_k * BK + tl.arange(0, BK)
o_q = i_t * BT + i_i * BC
m_k = o_k < K
if i_i > i_j:
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
# [BK,]
b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0)
# [BC, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_gs = tl.load(p_gs, boundary_check=(0, 1))
b_qg = (b_q * tl.exp(b_gs - b_gn[None, :]) * scale).to(b_q.dtype)
# [BK, BC]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)
# [BC, BC]
b_A = tl.dot(b_qg, b_kg, allow_tf32=False)
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
elif i_i == i_j:
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
p_q_self = tl.make_block_ptr(q + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
# [BC, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_gs = tl.load(p_gs, boundary_check=(0, 1))
o_i = tl.arange(0, BC)
o_g = i_bh * T * K + (i_t * BT + i_j * BC) * K + o_k
o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
p_u = tl.make_block_ptr(u + i_h * DK, (DK,), (1,), (i_k * BK), (BK,), (0,))
b_u = tl.load(p_u, boundary_check=(0,))
for j in range(0, BC):
# [BK,]
b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
b_gk = tl.load(g + o_g + j * K, mask=(m_k & ((i_t * BT + i_j * BC + j) < T)), other=0).to(tl.float32)
# [BC,]
b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_gs - b_gk[None, :]) * scale, 1)
b_A = tl.where(o_i > j, b_A, 0.)
# self
b_q_self = tl.load(p_q_self, boundary_check=(0,)).to(tl.float32)
A_self = tl.sum(b_q_self * b_k * b_u * scale, axis=0)
m_self = tl.arange(0, BC) == j
b_A = tl.where(m_self, A_self[None], b_A)
tl.store(A + o_A + j, b_A.to(A.dtype.element_ty), mask=m_A)
p_k = tl.advance(p_k, (K,))
p_q_self = tl.advance(p_q_self, (K,))
@triton.jit
def chunk_rwkv6_fwd_kernel_inter(
q,
v,
gs,
h,
o,
A,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
scale,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
b_o = tl.zeros([BT, BV], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BK]
b_gs = tl.load(p_gs, boundary_check=(0, 1))
# [BT, BK]
b_qg = (b_q * tl.exp(b_gs)).to(b_q.dtype)
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
# works but dkw, owing to divine benevolence
# [BT, BV]
if i_k >= 0:
b_o += tl.dot(b_qg, b_h, allow_tf32=False)
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BT, BT]
b_A = tl.load(p_A, boundary_check=(0, 1))
b_o += tl.dot(b_A, b_v, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_rwkv6_bwd_kernel_dh(
q,
g,
gs,
do,
dh,
dh0,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
scale,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BV]
b_do = tl.load(p_do, boundary_check=(0, 1))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
# [BK, BV]
b_dh *= tl.exp(b_gn)[:, None]
# [BK, BT]
b_gs = tl.load(p_gs, boundary_check=(0, 1))
b_q = (b_q * tl.exp(b_gs)).to(b_q.dtype)
# [BK, BV]
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
if USE_INITIAL_STATE:
p_dh0 = tl.make_block_ptr(dh0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_rwkv6_bwd_kernel_inter(
k,
v,
h,
g,
gs,
A,
do,
dh,
dq,
dk,
dv,
dA,
s_k_h,
s_k_t,
s_k_d,
s_v_h,
s_v_t,
s_v_d,
s_h_h,
s_h_t,
s_h_d,
scale,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
n_bh = tl.num_programs(2)
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_gq = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
# [BT, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_gq = tl.load(p_gq, boundary_check=(0, 1))
b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk)
b_k = (b_k * b_gn).to(b_k.dtype)
# [BT, BT]
b_A = tl.load(p_A, boundary_check=(0, 1))
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BT, BV]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BV]
b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
if i_k == 0:
b_dv += tl.dot(b_A, b_do, allow_tf32=False)
b_do = (b_do * scale).to(b_do.dtype)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# [BT, BT]
b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False)
# [BT, BK]
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
b_dq = b_dq * tl.exp(b_gq)
b_dk = b_dk * b_gn
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
o_i = tl.arange(0, BT)
m_s = o_i[:, None] > o_i[None, :]
# [BT, BT]
b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
if i_k == 0:
tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_rwkv6_bwd_kernel_intra(
q,
k,
g,
gs,
dA,
dq,
dk,
s_k_h,
s_k_t,
s_k_d,
T: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr
):
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_t, i_i = i_c // NC, i_c % NC
o_k = i_k * BK + tl.arange(0, BK)
o_q = i_t * BT + i_i * BC
m_k = o_k < K
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
# [BK,]
b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0)
# [BC, BK]
b_gs = tl.load(p_gs, boundary_check=(0, 1))
b_dq = tl.zeros([BC, BK], dtype=tl.float32)
for i_j in range(0, i_i):
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
# [BC, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
# [BC, BC]
b_dA = tl.load(p_dA, boundary_check=(0, 1))
# [BC, BK]
b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
b_dq *= tl.exp(b_gs - b_gn[None, :])
o_i = tl.arange(0, BC)
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
for j in range(0, BC):
p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
# [BC,]
b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
# [BK,]
b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
b_gkj = tl.load(g + i_bh * T * K + (o_q + j) * K + o_k, mask=(m_k & ((o_q + j) < T)), other=0)
# [BC, BK]
m_i = o_i[:, None] > j
# [BC, BK]
b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_gs - b_gkj[None, :]), 0.)
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.debug_barrier()
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
# [BK,]
b_gn = tl.load(p_gn, boundary_check=(0,))
# [BC, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_dk = tl.zeros([BC, BK], dtype=tl.float32)
for i_j in range(i_i + 1, NC):
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
# [BC, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_gs = tl.load(p_gs, boundary_check=(0, 1))
b_qg = (b_q * tl.exp(b_gs - b_gn[None, :])).to(b_q.dtype)
# [BC, BC]
b_dA = tl.load(p_dA, boundary_check=(0, 1))
# [BC, BK]
b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False)
b_dk *= tl.exp(b_gn[None, :] - b_gk)
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
for j in range(0, BC):
p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
p_gqj = tl.make_block_ptr(gs + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
# [BC,]
b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
# [BK,]
b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32)
# [BC, BK]
m_i = o_i[:, None] < j
b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.)
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
class ChunkRWKV6Function(torch.autograd.Function):
@staticmethod
@contiguous
def forward(ctx, r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level):
q = r # alias
B, H, T, K, V = *q.shape, v.shape[-1]
BT, BC = 64, 16
BK = min(64, triton.next_power_of_2(K))
BV = min(64, triton.next_power_of_2(V))
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
num_warps = 4 if BK == 64 else 2
num_stages = 1
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
h = q.new_empty(B, H, NT * K, V)
grid = (NV, NK, B * H)
chunk_rwkv6_fwd_kernel_h[grid](
k, v, g, h, h0, ht,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=h0 is not None,
STORE_FINAL_STATE=ht is not None,
num_warps=num_warps,
num_stages=num_stages
)
return h
final_state = None
if output_final_state:
final_state = q.new_empty(B, H, K, V, dtype=torch.float)
g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
# keep cummulative normalizer in fp32
# this kernel is equivalent to
# g_org = g_org.view(B, H, NT, BT, -1)
# g = g_org.cumsum(-2).view(B, H, T, -1)
# gs = g - g_org
chunk_rwkv6_fwd_kernel_cum[grid](
g_org, g, gs,
g.stride(1), g.stride(2), g.stride(3),
T=T, S=K, BT=BT
)
h = fwd_inner(
q=q, k=k, v=v, g=g,
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
h0=initial_state if initial_state is not None else None,
ht=final_state if final_state is not None else None
)
A = q.new_zeros(NK, B, H, T, BT)
grid = (NK, NT * NC * NC, B * H)
chunk_rwkv6_fwd_kernel_intra[grid](
q, k, g, gs, u, A,
k.stride(1), k.stride(2), k.stride(3),
scale,
H=H, T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, DK=K,
num_warps=num_warps,
num_stages=num_stages
)
A = A.sum(0, dtype=A.dtype)
o = torch.empty_like(v)
grid = (NV, NT, B * H)
chunk_rwkv6_fwd_kernel_inter[grid](
q, v, gs, h, o, A,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
scale,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
if checkpoint_level > 1:
del h
h, initial_state = None, None
del g, gs
ctx.save_for_backward(q, k, v, g_org, u, h, initial_state, A)
ctx.BT = BT
ctx.scale = scale
ctx.checkpoint_level = checkpoint_level
return o, final_state
@staticmethod
@contiguous
def backward(ctx, do, dht=None):
q, k, v, g, u, h, initial_state, A = ctx.saved_tensors
B, H, T, K, V = *q.shape, v.shape[-1]
BT, BC = ctx.BT, 16
BK = min(64, triton.next_power_of_2(K))
BV = min(64, triton.next_power_of_2(V))
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
NK = triton.cdiv(K, BK)
num_warps = 4 if BK == 64 else 2
num_stages = 1
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
h = q.new_empty(B, H, NT * K, V)
grid = (NV, NK, B * H)
chunk_rwkv6_fwd_kernel_h[grid](
k, v, g, h, h0, ht,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=h0 is not None,
STORE_FINAL_STATE=ht is not None,
num_warps=num_warps,
num_stages=num_stages
)
return h
def bwd_inner(q, g, gs, h0, do, B, H, T, K, V, BT, BK, BV, NT, scale):
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
dh = q.new_empty(B, H, NT * K, V)
dh0 = torch.empty_like(h0) if h0 is not None else None
grid = (NK, NV, B * H)
chunk_rwkv6_bwd_kernel_dh[grid](
q, g, gs, do, dh, dh0,
q.stride(1), q.stride(2), q.stride(3),
do.stride(1), do.stride(2), do.stride(3),
dh.stride(1), dh.stride(2), dh.stride(3),
scale,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=h0 is not None,
num_warps=num_warps,
num_stages=num_stages
)
return dh, dh0
# recompute cumulative log decays.
g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
# keep cummulative normalizer in fp32
# this kernel is equivalent to
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
chunk_rwkv6_fwd_kernel_cum[grid](
g_org, g, gs,
g.stride(1), g.stride(2), g.stride(3),
T=T, S=K, BT=BT
)
# rerun the forward pass to get h if checkpoint_level >= 1
if ctx.checkpoint_level == 1:
h = fwd_inner(
q=q, k=k, v=v, g=g,
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
h0=initial_state if initial_state is not None else None,
ht=None
)
scale = ctx.scale
dh, dh0 = bwd_inner(
q, g, gs, initial_state, do,
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
scale=scale
)
dq = torch.empty_like(q, dtype=torch.float)
dk = torch.empty_like(k, dtype=torch.float)
dv = v.new_empty(NK, *v.shape)
dA = q.new_zeros(B, H, T, BT)
grid = (NK, NT, B * H)
chunk_rwkv6_bwd_kernel_inter[grid](
k, v, h, g, gs, A, do, dh, dq, dk, dv, dA,
k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2), h.stride(3),
scale,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
dv = dv.sum(0, dtype=dv.dtype)
grid = (NK, NT * NC, B * H)
chunk_rwkv6_bwd_kernel_intra[grid](
q, k, g, gs, dA, dq, dk,
k.stride(1), k.stride(2), k.stride(3),
T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
num_warps=num_warps,
num_stages=num_stages
)
# TODO: fuse?
dg = (dq * q)[:, :, 1:] - (dk * k)[:, :, 0:-1]
dg = torch.nn.functional.pad(dg, (0, 0, 0, 1, 0, 0, 0, 0), value=0)
dg = chunk_reversed_cumsum_fwd(dg).to(g)
# equivalent to the following pytorch code.
# du = ((do * v).sum(-1)[..., None] * k * q * scale).sum(-2).to(u)
# dq += ((do * v).sum(-1)[..., None] * k * scale * u[:, :, None, :])
# dk += ((do * v).sum(-1)[..., None] * q * scale * u[:, :, None, :])
BT = 64
grid = (triton.cdiv(T, BT), B * H)
du = torch.empty_like(g, dtype=torch.float)
post_process_grad[grid](
q, k, v, u, do, dk, dq, du, scale,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3), H=H,
T=T, BT=BT, K=K, V=V, BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V),
num_warps=4
)
du = du.sum([0, 2])
return dq.to(q), dk.to(k), dv.to(v), dg.to(g), du.to(u), None, dh0, None, None
def chunk_rwkv6(
r: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
u: torch.Tensor,
scale: Optional[int] = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
checkpoint_level: Optional[int] = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
r (torch.Tensor):
reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.
k (torch.Tensor):
keys of shape `(B, H, T, K)`
v (torch.Tensor):
values of shape `(B, H, T, V)`
w (torch.Tensor):
data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.
u (torch.Tensor):
bonus of shape `(H, K)`
scale (Optional[int]):
Scale factor for the RWKV6 attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `(B, H, K, V)`. Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
checkpoint_level (Optional[int]):
Checkpointing level; higher values will save more memories and do more recomputations during backward.
Default: `0`:
- Level `0`: store forward hidden states for backprop.
- Level `1`: recompute the forward hidden states during backward.
"""
assert checkpoint_level in [0, 1]
if scale is None:
scale = r.shape[-1] ** -0.5
o, final_state = ChunkRWKV6Function.apply(r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level)
return o, final_state
if __name__ == "__main__":
import torch.nn.functional as F
from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6
B = 4
H = 4
L = 1024
K = 100
V = 120
torch.manual_seed(0)
dtype = torch.float32
q = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True)
k = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True)
v = torch.randn(B, H, L, V).cuda().to(dtype).requires_grad_(True)
w = (-torch.randn(B, H, L, K).exp()).cuda().to(torch.float32).requires_grad_(True)
u = torch.randn(H, K).cuda().to(dtype).requires_grad_(True)
h0 = torch.randn(B, H, K, V).cuda().to(dtype).requires_grad_(True)
do = torch.rand_like(v).cuda()
o, ht = fused_recurrent_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True)
o.backward(do)
dq, q.grad = q.grad.clone(), None
dk, k.grad = k.grad.clone(), None
dv, v.grad = v.grad.clone(), None
dw, w.grad = w.grad.clone(), None
du, u.grad = u.grad.clone(), None
dh0, h0.grad = h0.grad.clone(), None
o2, ht2 = chunk_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True)
o2.backward(do)
torch.testing.assert_close(o, o2, rtol=0, atol=1e-4)
torch.testing.assert_close(ht, ht2, rtol=0, atol=1e-4)
torch.testing.assert_close(q.grad, dq, rtol=0, atol=1e-4)
torch.testing.assert_close(k.grad, dk, rtol=0, atol=1e-4)
torch.testing.assert_close(v.grad, dv, rtol=0, atol=1e-4)
torch.testing.assert_close(w.grad, dw, rtol=0, atol=1e-4)
torch.testing.assert_close(u.grad, du, rtol=0, atol=2e-4)
torch.testing.assert_close(h0.grad, dh0, rtol=0, atol=2e-4)
print("All tests passed!")
@triton.testing.perf_report(
triton.testing.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['T'],
# different possible values for `x_name`
x_vals=[128 * 2 ** i for i in range(0, 8)],
# argument name whose value corresponds to a different line in the plot
line_arg='provider',
# possible values for `line_arg``
line_vals=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'],
# label name for the lines
line_names=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'],
# line styles
styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')],
ylabel="Execution Time (ms)", # label name for the y-axis
# name for the plot. Used also as a file name for saving the plot.
plot_name="Performance",
args={},
)
)
def benchmark(T, provider):
device = 'cuda'
dtype = torch.bfloat16
requires_grad = True
B, H, K = 16, 4, 128
q = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
k = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
v = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
w = F.logsigmoid(torch.randn(B, H, T, K)).to(dtype=dtype, device=device).requires_grad_(True)
u = torch.randn(H, K, device=device, requires_grad=requires_grad, dtype=dtype)
do = torch.ones_like(q, dtype=dtype)
quantiles = [0.5, 0.2, 0.8]
results = 0, 0, 0
if provider == 'recurrent':
results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u), quantiles=quantiles)
if provider == 'chunk':
results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u), quantiles=quantiles)
if provider == 'recurrent_bwd':
results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u)
[0].backward(do), quantiles=quantiles)
if provider == 'chunk_bwd':
results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u)[0].backward(do), quantiles=quantiles)
return results
benchmark.run(print_data=True)

View File

@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
import torch
from einops import rearrange
from fla.ops.rwkv6.chunk import chunk_rwkv6
from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6
def naive_chunk_rwkv6(
q,
k,
v,
w,
u,
chunk_size=32,
initial_state=None,
output_final_state=True,
):
assert q.shape[-2] % chunk_size == 0
orig_dtype = q.dtype
num_chunk = q.shape[-2] // chunk_size
u = u.unsqueeze(0)
q, k, v, w = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), (q, k, v, w))
w_cumsum = w.cumsum(-2)
kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp()
wkv = kw.transpose(-1, -2) @ v
wkv_new = torch.zeros_like(wkv)
for i in range(num_chunk - 1):
wkv_new[:, :, i+1] = (wkv_new[:, :, i] * w_cumsum[:, :, i, -1, :, None].exp()) + wkv[:, :, i]
o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, (q * (w_cumsum - w).exp()))
o_intra = torch.zeros_like(o_inter)
for i in range(chunk_size):
attn = (q[:, :, :, i, None] * k * (w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum).exp()).sum(-1)
mask = (torch.arange(0, chunk_size) < i).to(attn.device)
attn.masked_fill_(~mask, 0)
intra_inter_o = (attn.unsqueeze(-1) * v).sum(-2)
intra_intra_o = (q[:, :, :, i] * u.unsqueeze(2) * k[:, :, :, i]).sum(-1).unsqueeze(-1) * v[:, :, :, i]
o_intra[:, :, :, i] = intra_inter_o + intra_intra_o
o = o_inter + o_intra
return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype)
if __name__ == "__main__":
B = 4
H = 4
L = 1024
D = 100
dtype = torch.bfloat16
require_grad = True
q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad)
k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad)
v = torch.randn(B, H, L, 2*D).cuda().to(dtype).requires_grad_(require_grad)
w = torch.nn.functional.logsigmoid(torch.randn(B, H, L, D)).cuda().to(dtype).requires_grad_(require_grad)
u = (torch.randn(H, D).cuda().to(dtype)).requires_grad_(require_grad)
do = torch.rand_like(v).cuda()
o2, _ = chunk_rwkv6(q, k, v, w.clone(), u)
o, _ = fused_recurrent_rwkv6(q, k, v, w, u, scale=1.0)
o.backward(do)
dq, q.grad = q.grad.clone(), None
dk, k.grad = k.grad.clone(), None
dv, v.grad = v.grad.clone(), None
dw, w.grad = w.grad.clone(), None
du, u.grad = u.grad.clone(), None
print((o - o2).abs().max())
o2.backward(do)
print((o-o2).abs().max())
print((q.grad - dq).abs().max())
print((k.grad - dk).abs().max())
print((v.grad - dv).abs().max())
print((w.grad - dw).abs().max())
print((u.grad - du).abs().max())

View File

@@ -0,0 +1,378 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.ops.utils import chunk_reversed_cumsum_fwd
from fla.utils import contiguous
@triton.jit
def fused_recurrent_rwkv6_fwd_kernel(
q, # query [B, H, T, K]
k, # key [B, H, T, K]
v, # value [B, H, T, V]
w, # log gate [B, H, T, K]
u, # bonus [B, H, K]
o, # output [B, H, T, V]
# initial hidden state initialization [B, H, K, V]
h0,
ht, # final hidden state [B, H, K, V]
s_k_h, # stride size: T * K
s_v_h, # stride size: T * V
scale, # K ** -0.5
B: tl.constexpr,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h = i_bh % H
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK
mask_bk = (i_k * BK + tl.arange(0, BK)) < K
mask_bv = (i_v * BV + tl.arange(0, BV)) < V
mask_kv = mask_bv[:, None] & mask_bk[None, :]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
for _ in range(0, T):
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
b_w = tl.exp(b_w)
b_kv = b_k[None, :] * b_v[:, None]
b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :]
b_o = tl.sum(b_o, axis=1)
b_h = b_h * b_w[None, :]
b_h += b_kv
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
p_q += -K if REVERSE else K
p_k += -K if REVERSE else K
p_o += -V if REVERSE else V
p_v += -V if REVERSE else V
p_w += -K if REVERSE else K
if STORE_FINAL_STATE:
p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_rwkv6_bwd_kernel_dq(
# B: B, H: H, T: T, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
k, # key [B, H, T, V]
v, # value [B, H, T, V]
w, # log gate [B, H, T, K]
u, # bonus [B, H, K]
do, # gradient of output [B, H, T, V]
dq, # gradient of query [NV, B, H, T, K]
dq_aux, # gradient of query_aux [NV, B, H, T, K]
# initial hidden state initialization [B, H, K, V]
h0,
s_k_h, # stride size: T * K
s_v_h, # stride size: T * V
scale, # K ** -0.5
B: tl.constexpr, # B
H: tl.constexpr, # H
T: tl.constexpr, # T
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
K: tl.constexpr, # K
V: tl.constexpr, # V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h = i_bh % H
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_dq_aux = dq_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK
mask_bk = i_k * BK + tl.arange(0, BK) < K
mask_bv = i_v * BV + tl.arange(0, BV) < V
mask_kv = mask_bv[:, None] & mask_bk[None, :]
b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
for _ in range(0, T):
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
b_kv = b_k[None, :] * b_v[:, None]
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
b_w = tl.exp(b_w)
h_q = b_h * b_do[:, None]
b_dq = tl.sum(h_q + b_kv * b_u[None, :] * b_do[:, None], axis=0)
b_dq *= scale
b_dq_aux = tl.sum(h_q, axis=0)
b_h = b_h * b_w[None, :]
b_h += b_kv
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)
tl.store(p_dq_aux, b_dq_aux.to(p_dq_aux.dtype.element_ty), mask=mask_bk)
p_k += -K if REVERSE else K
p_do += -V if REVERSE else V
p_v += -V if REVERSE else V
p_w += -K if REVERSE else K
p_dq += -K if REVERSE else K
p_dq_aux += -K if REVERSE else K
@triton.jit
def fused_recurrent_rwkv6_bwd_kernel_dkv(
# B: B, H: H, T: T, D: d_head
# NV: number of split in the V dimension. NK: number of split in the K dimension
q, # query [B, H, T, K]
k, # key [B, H, T, V]
v, # value [B, H, T, V]
w, # log gate [B, H, T, K]
u, # bonus [B, H, K]
do, # gradient of output [B, H, T, V]
dk,
dk_aux,
dv,
dh0,
# initial hidden state initialization [B, H, K, V]
s_k_h, # stride size: T * K
s_v_h, # stride size: T * V
scale, # K ** -0.5
B, # B
H, # H
T, # T
BK: tl.constexpr, # BLOCK SIZE along the K dimension
BV: tl.constexpr, # BLOCK SIZE along the V dimension
K: tl.constexpr, # K
V: tl.constexpr, # V
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
):
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h = i_bh % H
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
p_dk_aux = dk_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
mask_bk = i_k * BK + tl.arange(0, BK) < K
mask_bv = i_v * BV + tl.arange(0, BV) < V
mask_kv = mask_bk[:, None] & mask_bv[None, :]
p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK
b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
for _ in range(T-1, -1, -1):
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
b_dkv = b_q[:, None] * b_do[None, :]
b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
tl.store(p_dk_aux, b_dk.to(p_dk_aux.dtype.element_ty), mask=mask_bk)
b_dk += tl.sum(b_dkv * b_u[:, None] * b_v[None, :], axis=1)
b_dv = tl.sum((b_dh + (b_dkv * b_u[:, None])) * b_k[:, None], axis=0)
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
b_dh *= tl.exp(b_w)[:, None]
b_dh += b_dkv
p_q += K if REVERSE else -K
p_k += K if REVERSE else -K
p_v += V if REVERSE else -V
p_w += K if REVERSE else -K
p_do += V if REVERSE else -V
p_dk += K if REVERSE else -K
p_dk_aux += K if REVERSE else -K
p_dv += V if REVERSE else -V
if USE_INITIAL_STATE:
p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_kv)
class FusedRecurrentRWKV6Function(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False):
# alias
q = r
B, H, T, K, V = *q.shape, v.shape[-1]
BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 1
if output_final_state:
final_state = q.new_empty(B, H, K, V)
else:
final_state = None
o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
grid = (NV, NK, B * H)
fused_recurrent_rwkv6_fwd_kernel[grid](
q, k, v, w, u, o, initial_state, final_state,
k.stride(1),
v.stride(1),
scale,
B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None,
REVERSE=reverse,
num_warps=num_warps,
num_stages=num_stages
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, w, u, initial_state, o)
ctx.scale = scale
ctx.reverse = reverse
# we do not need the gradient of the final state from the next chunk
# similiar to Trunctated BPTT
if final_state is not None:
final_state = final_state.detach()
return o.to(q.dtype), final_state
@staticmethod
@contiguous
@custom_bwd
def backward(ctx, do, d_final_state=None):
q, k, v, w, u, initial_state, o = ctx.saved_tensors
B, H, T, K, V = *q.shape, v.shape[-1]
scale = ctx.scale
BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 1
dq = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
dq_aux = torch.empty_like(dq)
grid = (NV, NK, B * H)
fused_recurrent_rwkv6_bwd_kernel_dq[grid](
k, v, w, u, do, dq, dq_aux, initial_state,
q.stride(1),
v.stride(1),
scale,
B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None,
REVERSE=ctx.reverse,
)
dq = dq.sum(0).to(q)
dq_aux = dq_aux.sum(0)
BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
dk = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
dk_aux = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
dv = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
dh0 = initial_state.new_empty(B, H, K, V) if initial_state is not None else None
grid = (NV, NK, B * H)
fused_recurrent_rwkv6_bwd_kernel_dkv[grid](
q, k, v, w, u, do, dk, dk_aux, dv, dh0,
q.stride(1),
v.stride(1),
scale,
B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages,
USE_INITIAL_STATE=initial_state is not None,
REVERSE=ctx.reverse,
)
dk = dk.sum(0).to(k)
dv = dv.sum(0).to(v)
dk_aux = dk_aux.sum(0)
dw = (dq_aux * q * scale)[:, :, 1:] - (dk_aux * k)[:, :, 0:-1]
dw = torch.nn.functional.pad(dw, (0, 0, 0, 1, 0, 0, 0, 0), value=0)
dw = chunk_reversed_cumsum_fwd(dw).to(w)
du = ((do * v).sum(-1)[..., None] * k * q * scale).sum([0, -2]).to(u)
return dq, dk, dv, dw, du, None, dh0, None, None
def fused_recurrent_rwkv6(
r: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
scale: int = -1,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
causal: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
r (torch.Tensor):
reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.
k (torch.Tensor):
keys of shape `(B, H, T, K)`
v (torch.Tensor):
values of shape `(B, H, T, V)`
w (torch.Tensor):
data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.
u (torch.Tensor):
bonus of shape `(H, K)`
scale (Optional[int]):
Scale factor for the RWKV6 attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `(B, H, K, V)`. Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
"""
if scale == -1:
scale = r.shape[-1] ** -0.5
o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state)
return o, final_state

View File

@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
from typing import Optional
import torch
def naive_recurrent_rwkv6(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
scale: Optional[float] = None,
initial_state: Optional[torch.Tensor] = None,
output_final_state: Optional[bool] = False
):
orig_dtype = q.dtype
B, H, T, K, V = *q.shape, v.shape[-1]
q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
o = torch.zeros_like(v)
if scale is None:
scale = K ** -0.5
if initial_state is not None:
h += initial_state
for i in range(T):
q_i = q[:, :, i, :] * scale
k_i = k[:, :, i]
v_i = v[:, :, i, :]
w_i = w[:, :, i].exp()
kv_i = k_i[..., None] * v_i[..., None, :]
o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None]
o[:, :, i] = o_i.sum(-2)
h = h * w_i[..., None] + kv_i
ht = h if output_final_state else None
return o.to(orig_dtype), ht
def naive_recurrent_rwkv6_bwd(
q,
k,
v,
w,
u,
o,
do,
initial_state=None,
output_final_state=False
):
q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
B, H, T, K, V = *q.shape, v.shape[-1]
h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
dq = torch.zeros_like(q)
dq_aux = torch.zeros_like(q)
if initial_state is not None:
h += initial_state
for i in range(T):
k_i = k[:, :, i]
v_i = v[:, :, i]
w_i = w[:, :, i].exp()
kv_i = k_i[..., None] * v_i[..., None, :]
h_i = (h + u[None, ..., None] * kv_i)
dq_i = (do[:, :, i, None, :] * h_i).sum(-1)
dq_aux_i = (do[:, :, i, None, :] * h).sum(-1)
dq[:, :, i] = dq_i
dq_aux[:, :, i] = dq_aux_i
h = h * w_i[..., None] + kv_i
du = torch.zeros_like(u)
dh = torch.zeros_like(h)
dk = torch.zeros_like(k)
dk_aux = torch.zeros_like(k)
dv = torch.zeros_like(v)
for i in range(T - 1, -1, -1):
d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
k_i = k[:, :, i]
v_i = v[:, :, i]
du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
du += du_i
dk_i = (dh * v_i[..., None, :]).sum(-1)
dk_aux[:, :, i] = dk_i
dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
dv_i += (dh * k_i[..., None]).sum(-2)
dk[:, :, i] = dk_i
dv[:, :, i] = dv_i
dh = dh * w[:, :, i, :, None].exp() + d_kv_i
# dw = q * dq_aux - k * dk_aux
dw = torch.zeros_like(w)
for i in range(T - 2, -1, -1):
dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]
return dq, dk, dv, dw, du

View File

@@ -0,0 +1,5 @@
- Simple GLA
Gating mechanism in https://arxiv.org/abs/2103.02143. Compared to GLA, the gating is head-wise instead of elementwise. As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. It is faster than GLA but has less expressive power. I will use it as a baseline for the GLA.
$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.

View File

@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
from .chunk import chunk_simple_gla
__all__ = [
'chunk_simple_gla'
]

View File

@@ -0,0 +1,415 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang
from typing import Tuple
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.utils import contiguous
@torch.jit.script
def normalize_output(q, k, o):
k = k.transpose(-2, -1)
k = k.cumsum(-1)
k = k.transpose(-2, -1)
z = (q * k).sum(-1, keepdim=True)
return o / (z + 1e-5)
@triton.jit
def chunk_simple_gla_fwd_kernel_h(
k,
v,
h,
g,
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
if USE_INITIAL_STATE:
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V,
(K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
for i_t in range(NT):
p_k = tl.make_block_ptr(
k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(
v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BK, BV]
b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)
b_h *= tl.math.exp2(b_g_last)
b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT))
b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(
final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_simple_gla_fwd_kernel_o(
q,
k,
v,
h,
g,
o,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
o_i = tl.arange(0, BT)
m_s = o_i[:, None] >= o_i[None, :]
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_s = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(
k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT]
# [BK, BV]
b_h = tl.load(p_h, boundary_check=(0, 1))
b_o += tl.dot(b_q, b_h, allow_tf32=False)
b_s += tl.dot(b_q, b_k, allow_tf32=False)
p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
b_g = tl.load(p_g)
b_o = b_o * tl.math.exp2(b_g)[:, None]
b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :])
b_s = tl.where(m_s, b_s, 0)
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V),
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V),
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def chunk_simple_gla_bwd_kernel_dh(
q,
g,
do,
dh,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V,
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T +
i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)
# [BT, V]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BK, BV]
b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1))
b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
@triton.jit
def chunk_simple_gla_bwd_kernel_dqkv(
q,
k,
v,
h,
g,
do,
dh,
dq,
dk,
dv,
s_qk_h,
s_qk_t,
s_qk_d,
s_vo_h,
s_vo_t,
s_vo_d,
s_h_h,
s_h_t,
scale,
B: tl.constexpr,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NT: tl.constexpr
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
n_bh = tl.num_programs(2)
o_i = tl.arange(0, BT)
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T),
(s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K),
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_s = tl.dot(b_k, b_q, allow_tf32=False)
p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
b_g = tl.load(p_g)
b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)
mask = tl.math.exp2(b_g[None, :] - b_g[:, None])
mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)
b_s = b_s * mask
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t),
(i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V),
(s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V),
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BT]
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
# [BT, BK]
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
# [BT, BV]
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.math.exp2(-b_g + b_g_last)[:, None] + \
tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
b_dq = b_dq * tl.math.exp2(b_g)[:, None]
b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None]
b_ds = b_ds * tl.trans(mask)
b_ds = b_ds.to(b_k.dtype)
# [BT, BK]
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K),
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K),
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
class SimpleGLAFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
@contiguous
def forward(ctx, q, k, v, g, initial_state, output_final_state):
B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(64, triton.next_power_of_2(K)), min(
64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
scale = K ** -0.5
BT = 64
assert T % BT == 0, 'sequence length must be divisible by BT'
g = g.reshape(B, H, -1, BT)
g = g.cumsum(-1) * 1.44269504
g = g.reshape(B, H, -1)
final_state = None
if output_final_state:
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
h = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_simple_gla_fwd_kernel_h[grid](
k, v, h, g, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
num_warps=num_warps,
num_stages=num_stages
)
grid = (NV, NT, B * H)
o = torch.empty_like(v)
chunk_simple_gla_fwd_kernel_o[grid](
q, k, v, h, g, o,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)
ctx.save_for_backward(q, k, v, h, g)
return o.to(q.dtype), final_state
@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_ht=None):
q, k, v, h, g = ctx.saved_tensors
B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min(
32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
scale = K ** -0.5
dh = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
chunk_simple_gla_bwd_kernel_dh[grid](
q, g, do, dh,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
grid = (NK, NT, B * H)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = v.new_empty(NK, *v.shape)
num_stages = 1
num_warps = 4 if BK == 64 else 2
chunk_simple_gla_bwd_kernel_dqkv[grid](
q, k, v, h, g, do, dh, dq, dk, dv,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
dv = dv.sum(0)
dg = (dq * q - dk * k).sum(-1)
def rev_cumsum(x):
cumsum_x = x.cumsum(-1)
rev_cumsum_x = cumsum_x[..., -1, None] - cumsum_x
return rev_cumsum_x + x
dg = rev_cumsum(dg)
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, None
def chunk_simple_gla(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor, # log decay
initial_state: torch.Tensor = None,
output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
g = g.float()
o, final_state = SimpleGLAFunction.apply(q, k, v, g, initial_state, output_final_state)
return o, final_state

View File

@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
import torch
from einops import rearrange
def torch_simple_gla(q, k, v, g, chunk_size=64):
q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5)
k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size)
v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size)
g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size)
g = g.cumsum(-1)
kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None])
S = torch.zeros_like(kv)
for i in range(1, g.shape[-2]):
S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1]
inter = (q * g[..., None].exp()) @ S
attn = q @ k.transpose(-1, -2)
attn = attn * (g[..., None] - g[..., None, :]).exp()
attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
intra = attn @ v
o = inter + intra
return rearrange(o, 'b h n c d -> b h (n c) d')
def torch_simple_gla_recurrent(q, k, v, g, chunk_size=64):
# q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5)
# k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size)
# v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size)
# g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size)
# g = g.cumsum(-1)
# kv = k.transpose(-1, -2) @ v
B, H, T, DK = q.shape
q = q * (DK ** -0.5)
_, _, _, DV = v.shape
S = torch.zeros(B, H, DK, DV).to(q)
o = torch.zeros(B, H, T, DV).to(q)
for i in range(T):
gate = g[:, :, i].exp()
key = k[:, :, i]
value = v[:, :, i]
kv = key.unsqueeze(-1) * value.unsqueeze(-2)
S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv
q_i = q[:, :, i, :]
o_i = (q_i.unsqueeze(-1) * S).sum(-2)
o[:, :, i] = o_i
return o

579
finetune/lora/v6/fla/ops/utils.py vendored Normal file
View File

@@ -0,0 +1,579 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
from typing import Optional
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
triton.Config({'BT': 16}, num_warps=4),
triton.Config({'BT': 16}, num_warps=8),
triton.Config({'BT': 32}, num_warps=2),
triton.Config({'BT': 32}, num_warps=4),
triton.Config({'BT': 32}, num_warps=8),
triton.Config({'BT': 64}, num_warps=2),
triton.Config({'BT': 64}, num_warps=4),
triton.Config({'BT': 64}, num_warps=8),
],
key=['S']
)
@triton.jit
def logcumsumexp_fwd_kernel(
s,
z,
s_s_h,
s_s_t,
s_s_d,
T: tl.constexpr,
S: tl.constexpr,
BT: tl.constexpr
):
i_bh = tl.program_id(0)
o_i = tl.arange(0, BT)
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
b_mp = tl.full([S,], float('-inf'), dtype=tl.float32)
b_zp = tl.zeros([S,], dtype=tl.float32)
for i_t in range(tl.cdiv(T, BT)):
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
# [BT, S]
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
# [S,]
b_mc = tl.max(b_s, 0)
# workaround for compiler bugs
if i_t > 0:
b_mc = tl.maximum(b_mp, b_mc)
b_zp = b_zp * tl.exp(b_mp - b_mc)
# [BT, S]
b_s = tl.exp(b_s - b_mc)
b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp
# [S,]
b_zc = tl.max(b_z, 0)
b_mp = b_mc
b_zp = b_zc
# [BT, BS]
# small eps to prevent underflows
b_z = tl.log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc
tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['S']
)
@triton.jit
def softmax_fwd_kernel(
s,
p,
s_s_h,
s_s_t,
s_s_d,
T: tl.constexpr,
S: tl.constexpr,
BT: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
# [BT, S]
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
# [BT]
b_m = tl.max(b_s, 1)
# [BT, BS]
b_s = tl.exp(b_s - b_m[:, None])
b_z = tl.sum(b_s, 1)
b_p = tl.where(b_s != 0, b_s / b_z[:, None], 0.)
tl.store(p_p, b_p.to(p_p.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['S']
)
@triton.jit
def softmax_bwd_kernel(
p,
dp,
ds,
s_s_h,
s_s_t,
s_s_d,
T: tl.constexpr,
S: tl.constexpr,
BT: tl.constexpr
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
p_dp = tl.make_block_ptr(dp + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
# [BT, BS]
b_p = tl.load(p_p, boundary_check=(0, 1)).to(tl.float32)
b_dp = tl.load(p_dp, boundary_check=(0, 1)).to(tl.float32)
# [BT,]
b_pp = tl.sum(b_p * b_dp, 1)
# [BT, BS]
b_ds = b_p * b_dp - b_p * b_pp[:, None]
tl.store(p_ds, b_ds.to(p_ds.dtype.element_ty), boundary_check=(0, 1))
@triton.autotune(
configs=[
triton.Config({'BS': 32}, num_warps=2),
triton.Config({'BS': 32}, num_warps=4),
triton.Config({'BS': 32}, num_warps=8),
triton.Config({'BS': 64}, num_warps=2),
triton.Config({'BS': 64}, num_warps=4),
triton.Config({'BS': 64}, num_warps=8),
triton.Config({'BS': 128}, num_warps=2),
triton.Config({'BS': 128}, num_warps=4),
triton.Config({'BS': 128}, num_warps=8),
],
key=['S']
)
@triton.jit
def recurrent_cumsum_fwd_kernel(
s,
z,
s_s_h,
s_s_t,
T: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr
):
i_s, i_bh = tl.program_id(0), tl.program_id(1)
o_s = i_s * BS + tl.arange(0, BS)
mask = o_s < S
b_z = tl.zeros([BS], dtype=tl.float32)
for i_t in range(0, T):
# [BS]
b_s = tl.load(s + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32)
b_z = b_z + b_s
tl.store(z + i_bh * s_s_h + i_t * s_s_t + o_s, b_z.to(s.dtype.element_ty), mask=mask)
@triton.autotune(
configs=[
triton.Config({'BS': 32}, num_warps=2),
triton.Config({'BS': 32}, num_warps=4),
triton.Config({'BS': 32}, num_warps=8),
triton.Config({'BS': 64}, num_warps=2),
triton.Config({'BS': 64}, num_warps=4),
triton.Config({'BS': 64}, num_warps=8),
triton.Config({'BS': 128}, num_warps=2),
triton.Config({'BS': 128}, num_warps=4),
triton.Config({'BS': 128}, num_warps=8),
],
key=['S']
)
@triton.jit
def recurrent_cumsum_bwd_kernel(
ds,
dz,
s_s_h,
s_s_t,
T: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr
):
i_s, i_bh = tl.program_id(0), tl.program_id(1)
o_s = i_s * BS + tl.arange(0, BS)
mask = o_s < S
b_ds = tl.zeros([BS], dtype=tl.float32)
for i_t in range(T - 1, -1, -1):
# [BS]
b_dz = tl.load(dz + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32)
b_ds = b_ds + b_dz
tl.store(ds + i_bh * s_s_h + i_t * s_s_t + o_s, b_ds.to(ds.dtype.element_ty), mask=mask)
@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
triton.Config({'BT': 16}, num_warps=4),
triton.Config({'BT': 16}, num_warps=8),
triton.Config({'BT': 32}, num_warps=2),
triton.Config({'BT': 32}, num_warps=4),
triton.Config({'BT': 32}, num_warps=8),
triton.Config({'BT': 64}, num_warps=2),
triton.Config({'BT': 64}, num_warps=4),
triton.Config({'BT': 64}, num_warps=8),
],
key=['S']
)
@triton.jit
def chunk_cumsum_fwd_kernel(
s,
z,
s_s_h,
s_s_t,
s_s_d,
T: tl.constexpr,
S: tl.constexpr,
BT: tl.constexpr,
BS: tl.constexpr
):
i_s, i_bh = tl.program_id(0), tl.program_id(1)
o_i = tl.arange(0, BT)
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
b_z = tl.zeros([BS], dtype=tl.float32)
for i_t in range(tl.cdiv(T, BT)):
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
# [BT, BS]
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
if i_t >= 0:
b_z += tl.sum(b_s, 0)
@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
triton.Config({'BT': 16}, num_warps=4),
triton.Config({'BT': 16}, num_warps=8),
triton.Config({'BT': 32}, num_warps=2),
triton.Config({'BT': 32}, num_warps=4),
triton.Config({'BT': 32}, num_warps=8),
triton.Config({'BT': 64}, num_warps=2),
triton.Config({'BT': 64}, num_warps=4),
triton.Config({'BT': 64}, num_warps=8),
],
key=['S']
)
@triton.jit
def chunk_cumsum_bwd_kernel(
ds,
dz,
s_s_h,
s_s_t,
s_s_d,
T: tl.constexpr,
S: tl.constexpr,
BT: tl.constexpr,
BS: tl.constexpr
):
i_s, i_bh = tl.program_id(0), tl.program_id(1)
o_i = tl.arange(0, BT)
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
b_ds = tl.zeros([BS], dtype=tl.float32)
for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_dz = tl.make_block_ptr(dz + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
# [BT, BS]
b_dz = tl.load(p_dz, boundary_check=(0, 1)).to(tl.float32)
b_c = b_ds[None, :] + tl.dot(m_s, b_dz, allow_tf32=False)
tl.store(p_ds, b_c.to(p_ds.dtype.element_ty), boundary_check=(0, 1))
if i_t >= 0:
b_ds += tl.sum(b_dz, 0)
@contiguous
def chunk_cumsum_fwd(
s: torch.Tensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
B, H, T, S = s.shape
BS = 32
dtype = dtype or s.dtype
grid = (triton.cdiv(S, BS), B * H)
z = torch.empty_like(s, dtype=dtype)
chunk_cumsum_fwd_kernel[grid](
s, z,
s.stride(1), s.stride(2), s.stride(3),
T=T, S=S, BS=BS
)
return z
@contiguous
def chunk_cumsum_bwd(
dz: torch.Tensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
B, H, T, S = dz.shape
BS = 32
dtype = dtype or dz.dtype
grid = (triton.cdiv(S, BS), B * H)
ds = torch.empty_like(dz, dtype=dtype)
chunk_cumsum_bwd_kernel[grid](
ds, dz,
ds.stride(1), ds.stride(2), ds.stride(3),
T=T, S=S, BS=BS
)
return ds
class CumsumFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, s, dtype):
z = chunk_cumsum_fwd(s, dtype)
ctx.dtype = dtype
return z
@staticmethod
def backward(ctx, dz):
ds = chunk_cumsum_bwd(dz, ctx.dtype)
return ds, None
def cumsum(
s: torch.Tensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
return CumsumFunction.apply(s, dtype)
@triton.autotune(
configs=[
triton.Config({'BS': 32}, num_warps=2),
triton.Config({'BS': 32}, num_warps=4),
triton.Config({'BS': 32}, num_warps=8),
triton.Config({'BS': 64}, num_warps=2),
triton.Config({'BS': 64}, num_warps=4),
triton.Config({'BS': 64}, num_warps=8),
triton.Config({'BS': 128}, num_warps=2),
triton.Config({'BS': 128}, num_warps=4),
triton.Config({'BS': 128}, num_warps=8),
],
key=['S']
)
@triton.jit
def recurrent_reversed_cumsum_fwd_kernel(
s,
z,
s_s_h,
s_s_t,
T: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr
):
i_s, i_bh = tl.program_id(0), tl.program_id(1)
o_s = i_s * BS + tl.arange(0, BS)
mask = o_s < S
b_z = tl.zeros([BS], dtype=tl.float32)
for i_t in range(T - 1, -1, -1):
# [BS]
b_s = tl.load(s + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32)
b_z = b_z + b_s
tl.store(z + i_bh * s_s_h + i_t * s_s_t + o_s, b_z.to(s.dtype.element_ty), mask=mask)
@triton.autotune(
configs=[
triton.Config({'BS': 32}, num_warps=2),
triton.Config({'BS': 32}, num_warps=4),
triton.Config({'BS': 32}, num_warps=8),
triton.Config({'BS': 64}, num_warps=2),
triton.Config({'BS': 64}, num_warps=4),
triton.Config({'BS': 64}, num_warps=8),
triton.Config({'BS': 128}, num_warps=2),
triton.Config({'BS': 128}, num_warps=4),
triton.Config({'BS': 128}, num_warps=8),
],
key=['S']
)
@triton.jit
def recurrent_reversed_cumsum_bwd_kernel(
ds,
dz,
s_s_h,
s_s_t,
T: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr
):
i_s, i_bh = tl.program_id(0), tl.program_id(1)
o_s = i_s * BS + tl.arange(0, BS)
mask = o_s < S
b_ds = tl.zeros([BS], dtype=tl.float32)
for i_t in range(0, T):
# [BS]
b_dz = tl.load(dz + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32)
b_ds = b_ds + b_dz
tl.store(ds + i_bh * s_s_h + i_t * s_s_t + o_s, b_ds.to(ds.dtype.element_ty), mask=mask)
@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
triton.Config({'BT': 16}, num_warps=4),
triton.Config({'BT': 16}, num_warps=8),
triton.Config({'BT': 32}, num_warps=2),
triton.Config({'BT': 32}, num_warps=4),
triton.Config({'BT': 32}, num_warps=8),
triton.Config({'BT': 64}, num_warps=2),
triton.Config({'BT': 64}, num_warps=4),
triton.Config({'BT': 64}, num_warps=8),
],
key=['S']
)
@triton.jit
def chunk_reversed_cumsum_fwd_kernel(
s,
z,
s_s_h,
s_s_t,
s_s_d,
T: tl.constexpr,
S: tl.constexpr,
BT: tl.constexpr,
BS: tl.constexpr
):
i_s, i_bh = tl.program_id(0), tl.program_id(1)
o_i = tl.arange(0, BT)
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
b_z = tl.zeros([BS], dtype=tl.float32)
for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
# [BT, BS]
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
if i_t >= 0:
b_z += tl.sum(b_s, 0)
@triton.autotune(
configs=[
triton.Config({'BT': 16}, num_warps=2),
triton.Config({'BT': 16}, num_warps=4),
triton.Config({'BT': 16}, num_warps=8),
triton.Config({'BT': 32}, num_warps=2),
triton.Config({'BT': 32}, num_warps=4),
triton.Config({'BT': 32}, num_warps=8),
triton.Config({'BT': 64}, num_warps=2),
triton.Config({'BT': 64}, num_warps=4),
triton.Config({'BT': 64}, num_warps=8),
],
key=['S']
)
@triton.jit
def chunk_reversed_cumsum_bwd_kernel(
ds,
dz,
s_s_h,
s_s_t,
s_s_d,
T: tl.constexpr,
S: tl.constexpr,
BT: tl.constexpr,
BS: tl.constexpr
):
i_s, i_bh = tl.program_id(0), tl.program_id(1)
o_i = tl.arange(0, BT)
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
b_ds = tl.zeros([BS], dtype=tl.float32)
for i_t in range(tl.cdiv(T, BT)):
p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_dz = tl.make_block_ptr(dz + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
# [BT, BS]
b_dz = tl.load(p_dz, boundary_check=(0, 1)).to(tl.float32)
b_c = b_ds[None, :] + tl.dot(m_s, b_dz, allow_tf32=False)
tl.store(p_ds, b_c.to(p_ds.dtype.element_ty), boundary_check=(0, 1))
if i_t >= 0:
b_ds += tl.sum(b_dz, 0)
@contiguous
def chunk_reversed_cumsum_fwd(
s: torch.Tensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
B, H, T, S = s.shape
BS = 32
dtype = dtype or s.dtype
grid = (triton.cdiv(S, BS), B * H)
z = torch.empty_like(s, dtype=dtype)
chunk_reversed_cumsum_fwd_kernel[grid](
s, z,
s.stride(1), s.stride(2), s.stride(3),
T=T, S=S, BS=BS
)
return z
@contiguous
def chunk_reversed_cumsum_bwd(
dz: torch.Tensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
B, H, T, S = dz.shape
BS = 32
dtype = dtype or dz.dtype
grid = (triton.cdiv(S, BS), B * H)
ds = torch.empty_like(dz, dtype=dtype)
chunk_reversed_cumsum_bwd_kernel[grid](
ds, dz,
ds.stride(1), ds.stride(2), ds.stride(3),
T=T, S=S, BS=BS
)
return ds
class ReversedCumsumFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, s, dtype):
z = chunk_reversed_cumsum_fwd(s, dtype)
ctx.dtype = dtype
return z
@staticmethod
def backward(ctx, dz):
ds = chunk_reversed_cumsum_bwd(dz, ctx.dtype)
return ds, None
def reversed_cumsum(
s: torch.Tensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
return CumsumFunction.apply(s, dtype)