18
finetune/lora/v6/fla/ops/__init__.py
vendored
Normal file
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-

from .based import fused_chunk_based, parallel_based
from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
from .retention import (chunk_retention, fused_chunk_retention,
                        fused_recurrent_retention, parallel_retention)

__all__ = [
    'fused_chunk_based',
    'parallel_based',
    'chunk_gla',
    'fused_chunk_gla',
    'fused_recurrent_gla',
    'chunk_retention',
    'fused_chunk_retention',
    'fused_recurrent_retention',
    'parallel_retention'
]
11
finetune/lora/v6/fla/ops/abc/__init__.py
vendored
Normal file
@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_abc
from .chunk_gate import chunk_gated_abc
from .recurrent_fuse import fused_recurrent_gated_abc

__all__ = [
    'chunk_abc',
    'chunk_gated_abc',
    'fused_recurrent_gated_abc'
]
1194
finetune/lora/v6/fla/ops/abc/chunk.py
vendored
Normal file
File diff suppressed because it is too large
1287
finetune/lora/v6/fla/ops/abc/chunk_gate.py
vendored
Normal file
File diff suppressed because it is too large
90
finetune/lora/v6/fla/ops/abc/naive.py
vendored
Normal file
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-

from typing import Optional

import torch


def naive_recurrent_abc(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    scale: Optional[int] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> torch.Tensor:
    dtype = q.dtype

    # [batch_size, n_heads, seq_len, n_slots]
    if g is None:
        z = s.float().logcumsumexp(2)
        g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
        s = torch.exp(s - z)
    q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
    B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]

    hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
    ok = torch.zeros_like(s)

    if scale is None:
        scale = q.shape[-1] ** -0.5

    final_state = None
    if initial_state is not None:
        hk += initial_state[0]

    for i in range(T):
        q_i = q[:, :, i] * scale
        k_i = k[:, :, i]
        v_i = s[:, :, i]
        g_i = g[:, :, i].exp()
        hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
        ok[:, :, i] = (q_i[..., None] * hk).sum(-2)

    qv = ok.softmax(-1)
    hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
    ov = torch.zeros_like(v)
    if initial_state is not None:
        hv += initial_state[1]

    for i in range(T):
        q_i = qv[:, :, i]
        k_i = s[:, :, i]
        v_i = v[:, :, i]
        g_i = g[:, :, i].exp()
        hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
        ov[:, :, i] = (q_i[..., None] * hv).sum(-2)

    if output_final_state:
        final_state = (hk, hv)
    return ov.to(dtype), final_state


def naive_cumsum_abc(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    s: torch.Tensor
) -> torch.Tensor:
    """
    A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper.
    This is just for demonstration purposes, with no numerical stability guaranteed.
    """

    dtype = q.dtype
    q, k, v, s = map(lambda x: x.float(), (q, k, v, s))

    scale = q.shape[-1] ** -0.5
    # [batch_size, n_heads, seq_len, n_slots]
    s = (s - s.max(2, True)[0]).exp()
    z = s.cumsum(2)
    # [batch_size, n_heads, seq_len, n_slots, d_head]
    K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
    V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1)
    # [batch_size, n_heads, seq_len, n_slots]
    p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1)
    # [batch_size, n_heads, seq_len, d_head]
    o = torch.einsum('...m,...md->...d', p, V)
    return o.to(dtype), None
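A minimal, unvendored sanity-check sketch for the two naive ABC references above: with no gate g and no initial state, the recurrent and cumulative-sum formulations should agree up to float32 tolerance. The import path assumes the vendored package is importable as fla.

# Hedged sketch (not part of the vendored file): compares naive_recurrent_abc and
# naive_cumsum_abc on random inputs; shapes follow the [B, H, T, d_head] / [B, H, T, n_slots]
# convention used in the comments above.
import torch

from fla.ops.abc.naive import naive_cumsum_abc, naive_recurrent_abc

if __name__ == "__main__":
    B, H, T, D, M = 2, 2, 64, 32, 16
    q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
    s = torch.randn(B, H, T, M)            # slot scores
    o_rec, _ = naive_recurrent_abc(q, k, v, s)
    o_cum, _ = naive_cumsum_abc(q, k, v, s)
    print((o_rec - o_cum).abs().max())     # expected to be small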
388
finetune/lora/v6/fla/ops/abc/recurrent_fuse.py
vendored
Normal file
@@ -0,0 +1,388 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2024, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_gated_abc_fwd_kernel(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
gk,
|
||||
gv,
|
||||
o,
|
||||
h0,
|
||||
ht,
|
||||
s_k_h,
|
||||
s_v_h,
|
||||
scale,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
USE_GK: tl.constexpr,
|
||||
USE_GV: tl.constexpr,
|
||||
):
|
||||
# indices
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
|
||||
if USE_GK:
|
||||
p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
if USE_GV:
|
||||
p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
|
||||
mask_bk = (i_k * BK + tl.arange(0, BK)) < K
|
||||
mask_bv = (i_v * BV + tl.arange(0, BV)) < V
|
||||
|
||||
h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
mask_kv = mask_bk[None, :] & mask_bv[:, None]
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
for _ in range(0, T):
|
||||
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
if USE_GK:
|
||||
b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
|
||||
h = h * b_gk[None, :]
|
||||
if USE_GV:
|
||||
b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
|
||||
h = h * b_gv[:, None]
|
||||
h += b_k[None, :] * b_v[:, None]
|
||||
b_o = h * b_q[None, :]
|
||||
b_o = tl.sum(b_o, axis=1)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
|
||||
p_q += -K if REVERSE else K
|
||||
p_k += -K if REVERSE else K
|
||||
p_o += -V if REVERSE else V
|
||||
p_v += -V if REVERSE else V
|
||||
if USE_GK:
|
||||
p_gk += -K if REVERSE else K
|
||||
if USE_GV:
|
||||
p_gv += -V if REVERSE else V
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_gated_abc_bwd_kernel(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
gk,
|
||||
gv,
|
||||
do,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
h0,
|
||||
s_k_h,
|
||||
s_v_h,
|
||||
scale,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
USE_GK: tl.constexpr,
|
||||
USE_GV: tl.constexpr,
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
if USE_GK:
|
||||
p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
if USE_GV:
|
||||
p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
mask_bk = i_k * BK + tl.arange(0, BK) < K
|
||||
mask_bv = i_v * BV + tl.arange(0, BV) < V
|
||||
mask_kv = mask_bk[:, None] & mask_bv[None, :]
|
||||
h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
|
||||
h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
for _ in range(0, T):
|
||||
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
if USE_GK:
|
||||
b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
|
||||
h = h * b_gk[:, None]
|
||||
if USE_GV:
|
||||
b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
|
||||
h = h * b_gv[None, :]
|
||||
h += b_k[:, None] * b_v[None, :]
|
||||
b_dq = tl.sum(h * b_do[None, :], axis=1) * scale
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)
|
||||
|
||||
p_k += -K if REVERSE else K
|
||||
p_v += -V if REVERSE else V
|
||||
p_q += -K if REVERSE else K
|
||||
p_do += -V if REVERSE else V
|
||||
p_dq += -K if REVERSE else K
|
||||
if USE_GK:
|
||||
p_gk += -K if REVERSE else K
|
||||
if USE_GV:
|
||||
p_gv += -V if REVERSE else V
|
||||
|
||||
# sync threads
|
||||
tl.debug_barrier()
|
||||
|
||||
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
|
||||
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
|
||||
p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
|
||||
if USE_GK:
|
||||
p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
if USE_GV:
|
||||
p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
|
||||
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for _ in range(T):
|
||||
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_dh += b_q[:, None] * b_do[None, :]
|
||||
b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
|
||||
b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
|
||||
if USE_GK:
|
||||
b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_dh *= b_gk[:, None]
|
||||
if USE_GV:
|
||||
b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_dh *= b_gv[None, :]
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
|
||||
|
||||
p_q += K if REVERSE else -K
|
||||
p_k += K if REVERSE else -K
|
||||
p_v += V if REVERSE else -V
|
||||
p_do += V if REVERSE else -V
|
||||
p_dk += K if REVERSE else -K
|
||||
p_dv += V if REVERSE else -V
|
||||
if USE_GK:
|
||||
p_gk += K if REVERSE else -K
|
||||
if USE_GV:
|
||||
p_gv += V if REVERSE else -V
|
||||
|
||||
|
||||
class FusedRecurrentGatedABCFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, q, k, v, s, g, scale=None, initial_state=None, output_final_state=False, reverse=False):
|
||||
B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
|
||||
# default scale
|
||||
if scale is None:
|
||||
scale = K ** -0.5
|
||||
|
||||
BK, BV, BM = min(K, 32), min(V, 32), min(M, 32)
|
||||
NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
g = g.float().exp()
|
||||
|
||||
final_state = (None, None)
|
||||
if output_final_state:
|
||||
final_state = (q.new_empty(B, H, K, M), q.new_empty(B, H, M, V))
|
||||
|
||||
ok = q.new_empty(NK, B, H, T, M, dtype=torch.float)
|
||||
gk, gv = None, g
|
||||
grid = (NM, NK, B * H)
|
||||
fused_recurrent_gated_abc_fwd_kernel[grid](
|
||||
q, k, s, gk, gv, ok, initial_state[0], final_state[0],
|
||||
k.stride(1),
|
||||
s.stride(1),
|
||||
scale=scale,
|
||||
B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,
|
||||
USE_INITIAL_STATE=initial_state[0] is not None,
|
||||
STORE_FINAL_STATE=final_state[0] is not None,
|
||||
USE_GK=False,
|
||||
USE_GV=True,
|
||||
REVERSE=reverse,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
ok = ok.sum(0)
|
||||
|
||||
qv = ok.softmax(-1, dtype=torch.float)
|
||||
ov = q.new_empty(NM, B, H, T, V, dtype=torch.float)
|
||||
gk, gv = g, None
|
||||
grid = (NV, NM, B * H)
|
||||
fused_recurrent_gated_abc_fwd_kernel[grid](
|
||||
qv, s, v, gk, gv, ov, initial_state[1], final_state[1],
|
||||
s.stride(1),
|
||||
v.stride(1),
|
||||
scale=1.,
|
||||
B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state[0] is not None,
|
||||
STORE_FINAL_STATE=final_state[0] is not None,
|
||||
USE_GK=True,
|
||||
USE_GV=False,
|
||||
REVERSE=reverse,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
ov = ov.sum(0)
|
||||
|
||||
ctx.save_for_backward(q, k, v, s, g, qv, *initial_state, ok)
|
||||
ctx.scale = scale
|
||||
ctx.reverse = reverse
|
||||
# we do not need the gradient of the final state from the next chunk
|
||||
# similar to truncated BPTT
|
||||
if final_state is not None:
|
||||
final_state = tuple(i.detach() for i in final_state)
|
||||
return ov.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_bwd
|
||||
def backward(ctx, do, dht=None):
|
||||
q, k, v, s, g, qv, *initial_state, ok = ctx.saved_tensors
|
||||
B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
|
||||
V = v.shape[-1]
|
||||
scale = ctx.scale
|
||||
|
||||
BK, BV, BM = min(K, 32), min(V, 32), min(M, 32)
|
||||
NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
|
||||
dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
|
||||
dv = q.new_empty(NM, B, H, T, V, dtype=torch.float)
|
||||
gk, gv = g, None
|
||||
grid = (NV, NM, B * H)
|
||||
fused_recurrent_gated_abc_bwd_kernel[grid](
|
||||
qv, s, v, gk, gv, do, dqv, dsv, dv, initial_state[1],
|
||||
s.stride(1),
|
||||
v.stride(1),
|
||||
scale=1.,
|
||||
B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state[1] is not None,
|
||||
REVERSE=ctx.reverse,
|
||||
USE_GK=gk is not None,
|
||||
USE_GV=gv is not None
|
||||
)
|
||||
dqv = dqv.sum(0)
|
||||
dsv = dsv.sum(0)
|
||||
dv = dv.sum(0)
|
||||
dgk = dqv * qv.float() - dsv * s.float()
|
||||
dgk_cumsum = dgk.cumsum(-2)
|
||||
dgk = dgk + dgk_cumsum[:, :, -1, None] - dgk_cumsum
|
||||
|
||||
dok = qv * (dqv - (qv * dqv).sum(-1, True))
|
||||
dq = q.new_empty(NM, B, H, T, K, dtype=torch.float)
|
||||
dk = q.new_empty(NM, B, H, T, K, dtype=torch.float)
|
||||
dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float)
|
||||
gk, gv = None, g
|
||||
grid = (NM, NK, B * H)
|
||||
fused_recurrent_gated_abc_bwd_kernel[grid](
|
||||
q, k, s, gk, gv, dok, dq, dk, dsk, initial_state[0],
|
||||
q.stride(1),
|
||||
s.stride(1),
|
||||
scale=scale,
|
||||
B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state[0] is not None,
|
||||
REVERSE=ctx.reverse,
|
||||
USE_GK=gk is not None,
|
||||
USE_GV=gv is not None
|
||||
)
|
||||
dq = dq.sum(0)
|
||||
dk = dk.sum(0)
|
||||
dsk = dsk.sum(0)
|
||||
|
||||
dgv = dok.float() * ok.float() - dsk * s.float()
|
||||
dgv_cumsum = dgv.cumsum(-2)
|
||||
dgv = dgv + dgv_cumsum[:, :, -1, None] - dgv_cumsum
|
||||
|
||||
ds = dsk.add_(dsv)
|
||||
dg = dgk.add_(dgv)
|
||||
|
||||
return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, None, None, None
|
||||
|
||||
|
||||
def fused_recurrent_gated_abc(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
s: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
scale: Optional[int] = None,
|
||||
initial_state: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_final_state: Optional[bool] = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `(B, H, T, K)`
|
||||
k (torch.Tensor):
|
||||
keys of shape `(B, H, T, K)`
|
||||
v (torch.Tensor):
|
||||
values of shape `(B, H, T, V)`
|
||||
g (torch.Tensor):
|
||||
Forget gates of shape `(B, H, T, M)` applied to keys.
|
||||
If not provided, this function is equivalent to vanilla ABC.
|
||||
scale (Optional[int]):
|
||||
Scale factor for attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[Tuple[torch.Tensor]]):
|
||||
Initial state tuple having tensors of shape `(B, H, K, V)`. Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state tuple, having tensors of shape `(B, H, K, V)`. Default: `False`.
|
||||
"""
|
||||
if initial_state is not None:
|
||||
initial_state = tuple(i.detach() for i in initial_state)
|
||||
if g is None:
|
||||
# TODO: these 3 steps take a huge amount of time and ought to be optimized
|
||||
z = s.float().logcumsumexp(2)
|
||||
g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
|
||||
s = torch.exp(s - z).to(k.dtype)
|
||||
if scale is None:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
ov, final_state = FusedRecurrentGatedABCFunction.apply(q, k, v, s, g, scale, initial_state, output_final_state)
|
||||
return ov, final_state
|
||||
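A hedged usage sketch for fused_recurrent_gated_abc as defined above. A CUDA device is assumed (the op is a Triton kernel), and a zero initial state plus output_final_state=True are passed explicitly because this vendored version indexes initial_state[0] and detaches the final-state tuple unconditionally.

import torch

from fla.ops.abc.naive import naive_recurrent_abc
from fla.ops.abc.recurrent_fuse import fused_recurrent_gated_abc

if __name__ == "__main__":
    B, H, T, K, V, M = 2, 4, 128, 64, 64, 32
    q = torch.randn(B, H, T, K, device='cuda')
    k = torch.randn(B, H, T, K, device='cuda')
    v = torch.randn(B, H, T, V, device='cuda')
    s = torch.randn(B, H, T, M, device='cuda')
    # zero initial states for the two passes: [B, H, K, M] for keys, [B, H, M, V] for values
    h0 = (q.new_zeros(B, H, K, M), q.new_zeros(B, H, M, V))
    # g is omitted, so the call reduces to vanilla (ungated) ABC per the docstring
    o, (hk, hv) = fused_recurrent_gated_abc(q, k, v, s, initial_state=h0, output_final_state=True)
    ref, _ = naive_recurrent_abc(q, k, v, s)
    print((o - ref).abs().max())           # expected to be small in float32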
9
finetune/lora/v6/fla/ops/based/__init__.py
vendored
Normal file
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

from .chunk_fuse import fused_chunk_based
from .parallel import parallel_based

__all__ = [
    'fused_chunk_based',
    'parallel_based'
]
410
finetune/lora/v6/fla/ops/based/chunk_fuse.py
vendored
Normal file
@@ -0,0 +1,410 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
# on-the-fly computation without materializing hidden states into HBM
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_chunk_based_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
z, # normalizer [B, H, L, 1]
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
):
|
||||
# indices
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
|
||||
# [BT, BT]
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
|
||||
# [BV], zero-order taylor expansion
|
||||
b_h_0o = tl.zeros([BV], dtype=tl.float32)
|
||||
# [BK, BV], first-order taylor expansion
|
||||
b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
# [BK, BK, BV] second-order taylor expansion
|
||||
b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
|
||||
|
||||
# make block pointers
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
|
||||
(s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
|
||||
|
||||
p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)
|
||||
k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
|
||||
k_1o = tl.zeros([1, BK], dtype=tl.float32)
|
||||
k_0o = 0
|
||||
|
||||
for i in range(0, tl.cdiv(T, BT)):
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BK*BK, BT]
|
||||
b_k_2o = b_k[:, None, :] * b_k[None, :, :]
|
||||
b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, BK]
|
||||
b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_z = tl.zeros([BT], dtype=tl.float32)
|
||||
|
||||
# interchunk
|
||||
# zero-order
|
||||
b_o += b_h_0o
|
||||
b_z += k_0o
|
||||
# first-order
|
||||
b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)
|
||||
b_z += tl.sum(b_q * k_1o, axis=1)
|
||||
# second-order
|
||||
b_q_2o = b_q[:, :, None] * b_q[:, None, :]
|
||||
b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)
|
||||
b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5
|
||||
b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5
|
||||
|
||||
# update running statistics
|
||||
k_1o += tl.sum(b_k, axis=1)[None, :]
|
||||
k_2o += tl.sum(b_k_2o, axis=1)[None, :]
|
||||
k_0o += BT
|
||||
|
||||
# intrachunk
|
||||
# [BT, BT]
|
||||
b_s = tl.dot(b_q, b_k, allow_tf32=False)
|
||||
b_s = 1 + b_s + 0.5 * b_s * b_s
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
b_z += tl.sum(b_s, axis=1)
|
||||
b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
|
||||
# [BT, BV]
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_z, b_z.to(p_z.dtype.element_ty),
|
||||
mask=(i * BT + tl.arange(0, BT)) < T)
|
||||
|
||||
# update hidden state
|
||||
# [BK, BV]
|
||||
b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)
|
||||
b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)
|
||||
b_h_0o = b_h_0o + tl.sum(b_v, axis=0)
|
||||
|
||||
p_q = tl.advance(p_q, (BT, 0))
|
||||
p_k = tl.advance(p_k, (0, BT))
|
||||
p_v = tl.advance(p_v, (BT, 0))
|
||||
p_o = tl.advance(p_o, (BT, 0))
|
||||
p_z += BT
|
||||
|
||||
|
||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
|
||||
@triton.jit
|
||||
def fused_chunk_based_bwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
# NV: number of splits along the V dimension; NK: number of splits along the K dimension
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
do, # gradient of output [B, H, L, D_head_V]
|
||||
dz, # gradient of normalizer [B, H, L]
|
||||
dq, # gradient of query [NV, B, H, L, D_head_K]
|
||||
dk, # gradient of key [NV, B, H, L, D_head_K]
|
||||
dv, # gradient of value [NK, B, H, L, D_head_V]
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
|
||||
# [BV], zero-order taylor expansion
|
||||
# b_h_0o = tl.zeros([BV], dtype=tl.float32)
|
||||
# [BK, BV], first-order taylor expansion
|
||||
b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
# [BK, BK, BV] second-order taylor expansion
|
||||
b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)
|
||||
|
||||
k_1o = tl.zeros([1, BK], dtype=tl.float32)
|
||||
k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
|
||||
|
||||
for i in range(0, tl.cdiv(T, BT)):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(
|
||||
k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(
|
||||
v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h,
|
||||
(T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
|
||||
p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
|
||||
# load tensors
|
||||
# [BT, BK]
|
||||
b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BV, BT]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
# inter-chunk
|
||||
b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)
|
||||
if i_v == 0:
|
||||
b_dq += b_dz[:, None] * k_1o
|
||||
b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5
|
||||
if i_v == 0:
|
||||
b_dq_2o += (b_dz[:, None] * k_2o) * 0.5
|
||||
b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])
|
||||
b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)
|
||||
b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)
|
||||
b_dq *= scale
|
||||
|
||||
# intra-chunk
|
||||
# [BT, BT]
|
||||
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
|
||||
if i_v == 0:
|
||||
b_ds += b_dz[:, None]
|
||||
b_ds = tl.where(m_s, b_ds, 0) * scale
|
||||
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)
|
||||
|
||||
# store
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
# update hidden state
|
||||
# [BT, BK*BK]
|
||||
b_k_2o = b_k[:, :, None] * b_k[:, None, :]
|
||||
b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
|
||||
# [BV, BK*BK]
|
||||
b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)
|
||||
# [BV, BK]
|
||||
b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)
|
||||
|
||||
if i_v == 0:
|
||||
# update running statistics
|
||||
k_1o += tl.sum(b_k, axis=0)[None, :]
|
||||
k_2o += tl.sum(b_k_2o, axis=0)[None, :]
|
||||
|
||||
tl.debug_barrier()
|
||||
b_h_1o = None
|
||||
b_h_2o = None
|
||||
|
||||
# [BK, BV], first-order taylor expansion
|
||||
b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
# [BK, BK, BV] second-order taylor expansion
|
||||
b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
|
||||
b_dh_0o = tl.zeros([BV], dtype=tl.float32)
|
||||
m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]
|
||||
|
||||
dq_1o = tl.zeros([1, BK], dtype=tl.float32)
|
||||
dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)
|
||||
|
||||
for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(
|
||||
k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(
|
||||
v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0))
|
||||
p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i
|
||||
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dv = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
|
||||
b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
|
||||
b_q = (b_q * scale).to(b_k.dtype)
|
||||
|
||||
# intra chunk
|
||||
b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
|
||||
if i_v == 0:
|
||||
b_ds += b_dz[None, :]
|
||||
b_ds = tl.where(m_s, b_ds, 0)
|
||||
b_s = tl.dot(b_k, b_q, allow_tf32=False)
|
||||
b_s2 = 1 + b_s + 0.5 * b_s * b_s
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
b_s2 = tl.where(m_s, b_s2, 0)
|
||||
b_ds *= (1+b_s)
|
||||
|
||||
b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)
|
||||
b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)
|
||||
|
||||
# inter chunk
|
||||
b_k_2o = b_k[:, :, None] * b_k[:, None, :]
|
||||
b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
|
||||
|
||||
b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)
|
||||
b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)
|
||||
b_dv += b_dh_0o
|
||||
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)
|
||||
|
||||
if i_v == 0:
|
||||
b_dk += dq_1o
|
||||
|
||||
b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype),
|
||||
tl.trans(b_v), allow_tf32=False)
|
||||
if i_v == 0:
|
||||
b_dk_2o += dq_2o
|
||||
b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])
|
||||
b_k_fp32 = tl.trans(b_k.to(tl.float32))
|
||||
b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)
|
||||
b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)
|
||||
b_dk += tl.trans(b_dk2)
|
||||
|
||||
# hidden state update
|
||||
b_dh_0o += tl.sum(b_do, axis=0)
|
||||
b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)
|
||||
b_q_2o = b_q[None, :, :] * b_q[:, None, :]
|
||||
b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)
|
||||
b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5
|
||||
|
||||
if i_v == 0:
|
||||
dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]
|
||||
dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]
|
||||
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class FusedChunkBasedFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, q, k, v, scale=1):
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
# assert d_head_qk == 16, "currently we do not support feature dim other than 16"
|
||||
d_head_v = v.shape[-1]
|
||||
|
||||
scale = scale
|
||||
BT = 16
|
||||
BK, BV = min(d_head_qk, 16), min(d_head_v, 32)
|
||||
BK, BV = max(BK, 16), max(BV, 16)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
|
||||
num_warps = 4
|
||||
|
||||
# the norm of o might explode, so we need to use float32 here
|
||||
o = q.new_empty(NK, batch_size, n_heads, seq_len,
|
||||
d_head_v, dtype=torch.float32)
|
||||
z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32)
|
||||
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
fused_chunk_based_fwd_kernel[grid](
|
||||
q, k, v, o, z,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
o = o.sum(0)
|
||||
z = z.sum(0)
|
||||
ctx.save_for_backward(q, k, v)
|
||||
ctx.scale = scale
|
||||
return o.to(q.dtype), z.to(z.dtype)
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_bwd
|
||||
def backward(ctx, do, dz):
|
||||
q, k, v = ctx.saved_tensors
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
scale = ctx.scale
|
||||
|
||||
BT = 16
|
||||
BK, BV = min(d_head_qk, 16), min(d_head_v, 32)
|
||||
BK, BV = max(BK, 16), max(BV, 16)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4
|
||||
|
||||
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
|
||||
fused_chunk_based_bwd_kernel[grid](
|
||||
q, k, v, do, dz, dq, dk, dv,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
dq = dq.sum(0)
|
||||
dk = dk.sum(0)
|
||||
dv = dv.sum(0)
|
||||
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None
|
||||
|
||||
|
||||
triton_fused_chunk_based = FusedChunkBasedFunction.apply
|
||||
|
||||
|
||||
def fused_chunk_based(q, k, v, use_scale=True, use_normalize=True):
|
||||
assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'
|
||||
if use_scale:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
else:
|
||||
scale = 1
|
||||
o, z = triton_fused_chunk_based(q, k, v, scale)
|
||||
if use_normalize:
|
||||
o = o / (z[..., None] + 1e-6)
|
||||
else:
|
||||
o = o
|
||||
|
||||
return o.to(q.dtype)
|
||||
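A small unvendored sketch of the feature map behind the zero-, first- and second-order accumulators in the kernels above: the Based score 1 + q.k + 0.5*(q.k)^2 equals phi(q).phi(k) with phi(x) = [1, x, flatten(outer(x, x)) / sqrt(2)].

import torch

def based_feature_map(x: torch.Tensor) -> torch.Tensor:
    # [1, x, flatten(outer(x, x)) / sqrt(2)] along the last dimension
    x2 = (x.unsqueeze(-1) * x.unsqueeze(-2)).flatten(-2) / (2 ** 0.5)
    return torch.cat([torch.ones_like(x[..., :1]), x, x2], dim=-1)

q, k = torch.randn(16), torch.randn(16)
lhs = 1 + q @ k + 0.5 * (q @ k) ** 2
rhs = based_feature_map(q) @ based_feature_map(k)
print(torch.allclose(lhs, rhs, atol=1e-5))  # True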
132
finetune/lora/v6/fla/ops/based/naive.py
vendored
Normal file
@@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-

import torch
from einops import rearrange

from fla.ops.based.chunk_fuse import fused_chunk_based
from fla.ops.based.parallel import parallel_based


def naive_parallel_based(q, k, v, use_scale=True, use_norm=True):
    if use_scale:
        q = q * (q.shape[-1] ** -0.5)
    attn = q @ k.transpose(-2, -1)
    attn = 1 + attn + 1/2 * (attn ** 2)
    attn.masked_fill_(~torch.tril(torch.ones(
        q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
    o = attn @ v
    if use_norm:
        z = attn.sum(-1)
        return o / (z[..., None] + 1e-6)
    else:
        return o


def naive_chunk_based(q, k, v, chunk_size=256):
    q = q * (q.shape[-1] ** -0.5)

    # compute normalizer.
    k_cumsum = torch.cumsum(k, dim=-2)
    kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
    # first-order term
    z = (q * k_cumsum).sum(-1)
    # second-order term
    z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
    # zeroth-order term
    z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]

    # compute o
    # constant term
    _o = v.cumsum(-2)

    q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)

    k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
    v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)

    intra_chunk_attn = q @ k.transpose(-2, -1)
    intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
    intra_chunk_attn.masked_fill_(
        ~torch.tril(
            torch.ones(chunk_size, chunk_size,
                       dtype=torch.bool, device=q.device),
        ), 0)
    o = intra_chunk_attn @ v

    # quadratic term
    kv = torch.einsum(
        'b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
    kv = kv.cumsum(2)
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)

    o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)

    # linear term
    kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
    kv = kv.cumsum(2)
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
    o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)

    o = rearrange(o, 'b h n c d -> b h (n c) d')
    o = o + _o
    return o / (z[..., None] + 1e-6)


if __name__ == "__main__":
    B = 4
    H = 4
    L = 128
    # D = 15
    dtype = torch.float32
    q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
    k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
    v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True)

    do = torch.randn_like(v).cuda()
    ref = naive_parallel_based(q, k, v, True, True)
    ref.backward(do, retain_graph=True)
    ref_dq, q.grad = q.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dv, v.grad = v.grad.clone(), None

    # tri = naive_chunk_based(q, k, v)
    # tri.backward(do, retain_graph=True)
    # tri_dq, q.grad = q.grad.clone(), None
    # tri_dk, k.grad = k.grad.clone(), None
    # tri_dv, v.grad = v.grad.clone(), None

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()

    tri = fused_chunk_based(q, k, v, True, True)
    tri.backward(do, retain_graph=True)
    tri_dq, q.grad = q.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dv, v.grad = v.grad.clone(), None
    print((ref-tri).abs().max())
    print((ref_dq-tri_dq).abs().max())
    print((ref_dk-tri_dk).abs().max())
    print((ref_dv-tri_dv).abs().max())

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()

    tri = parallel_based(q, k, v, True, True)
    tri.backward(do, retain_graph=True)
    tri_dq, q.grad = q.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dv, v.grad = v.grad.clone(), None

    print((ref-tri).abs().max())
    print((ref_dq-tri_dq).abs().max())
    print((ref_dk-tri_dk).abs().max())
    print((ref_dv-tri_dv).abs().max())

    # assert ref.allclose(tri, 0, 1e-4), breakpoint()
    # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
    # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
    # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
388
finetune/lora/v6/fla/ops/based/parallel.py
vendored
Normal file
@@ -0,0 +1,388 @@
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
# Based: An Educational and Effective Sequence Mixer
|
||||
# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
|
||||
|
||||
|
||||
@triton.jit
|
||||
def parallel_based_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
z, # normalizer [B, H, L]
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q
|
||||
BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
):
|
||||
# i_c: chunk index. used for sequence parallelism
|
||||
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
NV = tl.cdiv(DV, BV)
|
||||
i_k = i_kv // (NV)
|
||||
i_v = i_kv % (NV)
|
||||
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
|
||||
(s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))
|
||||
|
||||
# [BQ, BD] block Q, in the shared memory throughout the whole kernel
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
b_o = tl.zeros([BTL, BV], dtype=tl.float32)
|
||||
b_z = tl.zeros([BTL], dtype=tl.float32)
|
||||
|
||||
# Q block and K block have no overlap
|
||||
# no need for mask, thereby saving flops
|
||||
for _ in range(0, i_c * BTL, BTS):
|
||||
# [BK, BTS]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
|
||||
# [BTS, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
b_s = tl.dot(b_q, (b_k), allow_tf32=False)
|
||||
b_s = 1 + b_s + 0.5 * b_s * b_s
|
||||
b_z += tl.sum(b_s, axis=1)
|
||||
|
||||
# [BQ, BD]
|
||||
b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
|
||||
p_k = tl.advance(p_k, (0, BTS))
|
||||
p_v = tl.advance(p_v, (BTS, 0))
|
||||
|
||||
# # rescale interchunk output
|
||||
tl.debug_barrier()
|
||||
o_q = tl.arange(0, BTL)
|
||||
# # sync threads, easy for compiler to optimize
|
||||
# tl.debug_barrier()
|
||||
|
||||
o_k = tl.arange(0, BTS)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
|
||||
(s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
|
||||
# Q block and K block have overlap. masks required
|
||||
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
|
||||
# [BK, BTS]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BTS, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
m_s = o_q[:, None] >= o_k[None, :]
|
||||
b_s = tl.dot(b_q, b_k, allow_tf32=False)
|
||||
b_s = 1 + b_s + 0.5 * b_s * b_s
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
b_z += tl.sum(b_s, axis=1)
|
||||
# [BTL, BV]
|
||||
b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
|
||||
|
||||
p_k = tl.advance(p_k, (0, BTS))
|
||||
p_v = tl.advance(p_v, (BTS, 0))
|
||||
o_k += BTS
|
||||
|
||||
p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
|
||||
p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_z, b_z.to(p_z.dtype.element_ty),
|
||||
mask=((i_c * BTL + tl.arange(0, BTL)) < T))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _parallel_based_bwd_dq(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale,
|
||||
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
|
||||
DK: tl.constexpr, DV: tl.constexpr,
|
||||
):
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
|
||||
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
|
||||
p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
|
||||
(s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))
|
||||
p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)
|
||||
b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)
|
||||
|
||||
for _ in range(0, i_c * BTL, BTS):
|
||||
# [BTS, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BV, BTS]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
|
||||
if i_v == 0:
|
||||
b_ds += b_dz[:, None]
|
||||
else:
|
||||
b_ds = b_ds
|
||||
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
|
||||
# [BQ, BD]
|
||||
b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False)
|
||||
p_k = tl.advance(p_k, (BTS, 0))
|
||||
p_v = tl.advance(p_v, (0, BTS))
|
||||
|
||||
b_dq *= scale
|
||||
o_q = tl.arange(0, BTL)
|
||||
o_k = tl.arange(0, BTS)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
|
||||
(s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
|
||||
# Q block and K block have overlap. masks required
|
||||
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
|
||||
# [BTS, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BV, BTS]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
m_s = o_q[:, None] >= o_k[None, :]
|
||||
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
|
||||
if i_v == 0:
|
||||
b_ds += b_dz[:, None]
|
||||
else:
|
||||
b_ds = b_ds
|
||||
b_ds = tl.where(m_s, b_ds, 0) * scale
|
||||
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
# [BTL, BK]
|
||||
b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype),
|
||||
b_k, allow_tf32=False)
|
||||
p_k = tl.advance(p_k, (BTS, 0))
|
||||
p_v = tl.advance(p_v, (0, BTS))
|
||||
o_k += BTS
|
||||
p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _parallel_based_bwd_dkv(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale,
|
||||
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
|
||||
DK: tl.constexpr, DV: tl.constexpr,
|
||||
):
|
||||
# compute dk dv
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
|
||||
(i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
|
||||
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
|
||||
b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(
|
||||
p_v, boundary_check=(0, 1))
|
||||
b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
|
||||
[BTL, BV], dtype=tl.float32)
|
||||
|
||||
for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
|
||||
p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS]
|
||||
b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
|
||||
b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \
|
||||
scale # [BTL, BTS]
|
||||
b_s2 = 1 + b_s + 0.5 * b_s * b_s
|
||||
b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
|
||||
b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
|
||||
if i_v == 0:
|
||||
b_ds += b_dz[None, :] * scale
|
||||
else:
|
||||
b_ds = b_ds
|
||||
b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype),
|
||||
tl.trans(b_q), allow_tf32=False)
|
||||
|
||||
tl.debug_barrier()
|
||||
o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
|
||||
for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
|
||||
p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
|
||||
b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
|
||||
# [BK, BQ]
|
||||
m_s = o_k[:, None] <= o_q[None, :]
|
||||
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
|
||||
b_s2 = 1 + b_s + 0.5 * b_s * b_s
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
b_s2 = tl.where(m_s, b_s2, 0)
|
||||
|
||||
b_ds = tl.dot(b_v, b_do, allow_tf32=False)
|
||||
if i_v == 0:
|
||||
b_ds += b_dz[None, :]
|
||||
else:
|
||||
b_ds = b_ds
|
||||
b_ds = tl.where(m_s, b_ds, 0) * scale
|
||||
# [BK, BD]
|
||||
b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
|
||||
b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype),
|
||||
tl.trans(b_q), allow_tf32=False)
|
||||
o_q += BTS
|
||||
|
||||
p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,
|
||||
(T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,
|
||||
(T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def parallel_based_bwd_kernel(
|
||||
q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale,
|
||||
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
|
||||
DK: tl.constexpr, DV: tl.constexpr,
|
||||
):
|
||||
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
NV = tl.cdiv(DV, BV)
|
||||
i_k = i_kv // (NV)
|
||||
i_v = i_kv % (NV)
|
||||
i_h = i_bh % H
|
||||
_parallel_based_bwd_dq(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV
|
||||
)
|
||||
tl.debug_barrier()
|
||||
_parallel_based_bwd_dkv(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV
|
||||
)
|
||||
|
||||
|
||||
class ParallelBasedFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, q, k, v, scale):
|
||||
BTL, BTS = 128, 32
|
||||
assert BTL % BTS == 0
|
||||
# assert q.shape[-1] % 16 == 0
|
||||
BK = min(128, triton.next_power_of_2(k.shape[-1]))
|
||||
BV = min(128, triton.next_power_of_2(v.shape[-1]))
|
||||
BK, BV = max(BK, 16), max(BV, 16)
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
num_stages = 2
|
||||
num_warps = 4
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
NV = triton.cdiv(d_head_v, BV)
|
||||
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
|
||||
|
||||
assert NK == 1, "will encounter some synchronization issue if not."
|
||||
|
||||
o = torch.empty(NK, batch_size, n_heads, seq_len,
|
||||
d_head_v, device=q.device)
|
||||
z = torch.empty(NK, batch_size, n_heads, seq_len,
|
||||
device=q.device)
|
||||
parallel_based_fwd_kernel[grid](
|
||||
q, k, v, o, z,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
ctx.save_for_backward(q, k, v)
|
||||
ctx.scale = scale
|
||||
return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
@contiguous
|
||||
def backward(ctx, do, dz):
|
||||
q, k, v = ctx.saved_tensors
|
||||
scale = ctx.scale
|
||||
BTL, BTS = 64, 32
|
||||
assert BTL % BTS == 0
|
||||
BK = min(128, triton.next_power_of_2(k.shape[-1]))
|
||||
BV = min(128, triton.next_power_of_2(v.shape[-1]))
|
||||
BK, BV = max(BK, 16), max(BV, 16)
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
num_stages = 2
|
||||
num_warps = 4
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
NV = triton.cdiv(d_head_v, BV)
|
||||
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
|
||||
|
||||
assert NK == 1, "will encounter some synchronization issue if not"
|
||||
|
||||
dq = torch.empty(NV, batch_size, n_heads, seq_len,
|
||||
d_head_qk, dtype=q.dtype, device=q.device)
|
||||
dk = torch.empty(NV, batch_size, n_heads, seq_len,
|
||||
d_head_qk, dtype=q.dtype, device=q.device)
|
||||
dv = torch.empty(NK, batch_size, n_heads, seq_len,
|
||||
d_head_v, dtype=q.dtype, device=q.device)
|
||||
|
||||
parallel_based_bwd_kernel[grid](
|
||||
q, k, v, do, dz, dq, dk, dv,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
|
||||
|
||||
|
||||
triton_parallel_based = ParallelBasedFunction.apply
|
||||
|
||||
|
||||
def parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=False):
|
||||
assert q.shape[-1] <= 128, "only support feature dim up to 128"
|
||||
if use_scale:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
else:
|
||||
scale = 1
|
||||
o, z = triton_parallel_based(q, k, v, scale)
|
||||
if return_both:
|
||||
return o, z
|
||||
if use_normalize:
|
||||
o = o / (z[..., None] + 1e-6)
|
||||
else:
|
||||
o = o
|
||||
return o.to(q.dtype)
|
||||
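A hedged usage sketch for the parallel_based wrapper above. A CUDA device is assumed for the Triton kernels, and the head dimension stays within the asserted limit of 128.

import torch

from fla.ops.based.parallel import parallel_based

if __name__ == "__main__":
    B, H, T, DK, DV = 2, 4, 256, 16, 64
    q = torch.randn(B, H, T, DK, device='cuda')
    k = torch.randn(B, H, T, DK, device='cuda')
    v = torch.randn(B, H, T, DV, device='cuda')
    # normalized output in one call
    o = parallel_based(q, k, v, use_scale=True, use_normalize=True)
    # or keep the unnormalized output and the normalizer and renormalize manually
    o_raw, z = parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=True)
    print(torch.allclose(o, (o_raw / (z[..., None] + 1e-6)).to(q.dtype), atol=1e-4))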
4
finetune/lora/v6/fla/ops/delta_rule/README.md
vendored
Normal file
@@ -0,0 +1,4 @@
- Delta Rule

The implementation of the delta rule described in https://arxiv.org/abs/2102.11174
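For orientation, an unvendored PyTorch sketch of the delta-rule recurrence from the referenced paper: the state stores key-value associations, each step replaces the value currently bound to k_t by a beta-weighted blend with v_t, and the output reads the state with q_t. It is only illustrative; the fused/chunked kernels in this directory may additionally normalize keys and constrain beta.

import torch

def naive_delta_rule(q, k, v, beta):
    # q, k: [B, H, T, K]; v: [B, H, T, V]; beta: [B, H, T] (per-step writing strength)
    B, H, T, K = k.shape
    S = k.new_zeros(B, H, K, v.shape[-1])                    # fast-weight state, [K, V] per head
    o = torch.zeros_like(v)
    for t in range(T):
        k_t, v_t, b_t = k[:, :, t], v[:, :, t], beta[:, :, t, None]
        v_old = torch.einsum('bhkv,bhk->bhv', S, k_t)        # value currently bound to k_t
        S = S + torch.einsum('bhk,bhv->bhkv', k_t, b_t * (v_t - v_old))  # delta update
        o[:, :, t] = torch.einsum('bhkv,bhk->bhv', S, q[:, :, t])
    return o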
11
finetune/lora/v6/fla/ops/delta_rule/__init__.py
vendored
Normal file
@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-

from .chunk_fuse import fused_chunk_delta_rule
from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule
from .chunk import chunk_delta_rule

__all__ = [
    'fused_chunk_delta_rule',
    'fused_recurrent_linear_attn_delta_rule',
    'chunk_delta_rule'
]
544
finetune/lora/v6/fla/ops/delta_rule/chunk.py
vendored
Normal file
544
finetune/lora/v6/fla/ops/delta_rule/chunk.py
vendored
Normal file
@@ -0,0 +1,544 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang

import torch
import triton
import triton.language as tl
from fla.ops.utils import contiguous
from torch.cuda.amp import custom_bwd, custom_fwd
from fla.ops.delta_rule.wy_fast import fwd_recompute_w_u, fwd_prepare_wy_repr, bwd_prepare_wy_repr
from fla.ops.delta_rule.chunk_fuse import fused_chunk_delta_rule_fwd, fused_chunk_delta_rule_bwd
# from fla.ops.delta_rule.utils import bwd_prepare_wy_repr


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_dv_kernel(
    q,
    k,
    do,
    dv,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    T,
    K,
    V,
    scale,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)

    b_A = tl.zeros([BT, BT], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_k.dtype)
        b_A += tl.dot(b_k, b_q, allow_tf32=False)

    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)

    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


def fwd_prepare_dv(q, k, do, BT):
    dv = torch.empty_like(do)
    B, H, T, K, V = *k.shape, do.shape[-1]
    NT = triton.cdiv(T, BT)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    fwd_prepare_dv_kernel[(NT, B*H)](
        q, k, do, dv,
        k.stride(1), k.stride(2), k.stride(3),
        do.stride(1), do.stride(2), do.stride(3),
        T, K, V, K**-0.5, BT, BK, BV
    )
    return dv


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_fwd_kernel_h(
    k,
    v,
    d,
    v_new,
    h,
    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr
):
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32)
        # since we need to keep all of DK in SRAM, we face a severe SRAM memory burden; subchunking alleviates it
        for i_c in range(tl.cdiv(BT, BC)):
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
            p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # [BT, BK]
            b_d = tl.load(p_d, boundary_check=(0, 1))
            # [BT, BV]
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)
            # [BK, BV]
            tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
            b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
        b_h += b_h_cumsum

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_linear_attn_fwd_kernel_o(
    q,
    k,
    v,
    h,
    o,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    o_i = tl.arange(0, BT)
    m_s = o_i[:, None] >= o_i[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_s = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_q, b_h, allow_tf32=False)
        b_s += tl.dot(b_q, b_k, allow_tf32=False)

    b_s = tl.where(m_s, b_s, 0)
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False))
    p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_bwd_kernel_dhu(
    q,
    k,
    d,
    do,
    dh,
    dv,
    dv2,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr
):
    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    for i_t in range(NT - 1, -1, -1):
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
            p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
            p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
            p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BT]
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_q = (b_q * scale).to(b_q.dtype)
            # [BT, BK]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_d = tl.load(p_d, boundary_check=(0, 1))
            # [BT, V]
            b_do = tl.load(p_do, boundary_check=(0, 1))

            # [BT, BT]
            # b_s = tl.dot(b_k, b_q, allow_tf32=False)
            # b_s = tl.where(m_s, b_s, 0)
            # b_dv = tl.dot(b_s.to(b_do.dtype), b_do, allow_tf32=False) + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)

            b_dv = tl.load(p_dv, boundary_check=(0, 1))
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
            p_dv2 = tl.make_block_ptr(dv2 + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
            # [BK, BV]
            b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
            b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
        b_dh += b_dh_tmp


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def chunk_delta_rule_bwd_kernel_dqkw(
    q,
    k,
    v,
    w,
    h,
    do,
    dh,
    dq,
    dk,
    dv,
    dw,
    s_qk_h,
    s_qk_t,
    s_qk_d,
    s_vo_h,
    s_vo_t,
    s_vo_d,
    s_h_h,
    s_h_t,
    scale,
    H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    n_bh = tl.num_programs(2)
    o_i = tl.arange(0, BT)

    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
    b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dw = tl.zeros([BT, BK], dtype=tl.float32)
    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BT, BT]
        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)

        b_dv = tl.load(p_dv, boundary_check=(0, 1))
        b_dw += tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)

    # [BT, BT]
    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
    # [BT, BK]
    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))

    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))


def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
    B, H, T, K, V = *k.shape, u.shape[-1]

    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    BV = 16 if BK > 128 else 32
    BV = 64 if BK <= 64 else BV
    BC = 16 if BK > 128 else 32
    BC = 64 if BK <= 64 else BC
    BC = min(BT, BC)
    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    h = k.new_empty(B, H, NT * K, V)
    grid = (NK, NV, B * H)
    v_new = torch.empty_like(u)
    chunk_delta_rule_fwd_kernel_h[grid](
        k, u, w, v_new, h, initial_state, final_state,
        k.stride(1), k.stride(2), k.stride(3),
        u.stride(1), u.stride(2), u.stride(3),
        h.stride(1), h.stride(2),
        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
        USE_INITIAL_STATE=initial_state is not None,
        STORE_FINAL_STATE=final_state is not None,
    )
    return h, v_new


def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
    B, H, T, K, V = *q.shape, do.shape[-1]

    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension being larger than 256."
    BV = 16 if BK > 128 else 32
    BV = 64 if BK <= 64 else BV
    BC = 16 if BK > 128 else 32
    BC = 64 if BK <= 64 else BC
    BC = min(BT, BC)
    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    dh = q.new_empty(B, H, NT * K, V)
    # dv_new = torch.empty_like(do)
    grid = (NK, NV, B * H)
    dv2 = torch.empty_like(dv)
    chunk_delta_rule_bwd_kernel_dhu[grid](
        q, k, w, do, dh, dv, dv2,
        q.stride(1), q.stride(2), q.stride(3),
        do.stride(1), do.stride(2), do.stride(3),
        dh.stride(1), dh.stride(2),
        K**-0.5,
        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
    )
    return dh, dv2


def chunk_fwd_o_fn(q, k, v_new, h, BT):
    B, H, T, K, V = *q.shape, v_new.shape[-1]

    o = torch.empty_like(v_new)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    NV = triton.cdiv(V, BV)
    NT = triton.cdiv(T, BT)
    grid = (NV, NT, B * H)
    chunk_linear_attn_fwd_kernel_o[grid](
        q, k, v_new, h, o,
        q.stride(1), q.stride(2), q.stride(3),
        v_new.stride(1), v_new.stride(2), v_new.stride(3),
        h.stride(1), h.stride(2),
        scale=K**-0.5,
        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
    )
    return o


def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
    B, H, T, K, V = *q.shape, v_new.shape[-1]

    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    NV = triton.cdiv(V, BV)
    NT = triton.cdiv(T, BT)
    grid = (NV, NT, B * H)
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dw = torch.empty_like(w)
    chunk_delta_rule_bwd_kernel_dqkw[grid](
        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
        q.stride(1), q.stride(2), q.stride(3),
        v_new.stride(1), v_new.stride(2), v_new.stride(3),
        dh.stride(1), dh.stride(2),
        scale=K**-0.5,
        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
    )
    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)


class ChunkDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    @contiguous
    def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):
        # obtain the WY representation; u is effectively the new v.
        w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
        # forward_h
        final_state = None
        if output_final_state:
            B, H, K, V = q.shape[0], q.shape[1], q.shape[-1], v.shape[-1]
            final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
        # obtain output
        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
        # save memory
        if checkpoint_level == 1:
            h, v_new = None, None
        ctx.save_for_backward(q, k, v, beta, A, h, v_new, initial_state)
        ctx.BT = BT
        return o.to(q.dtype), final_state

    @staticmethod
    @custom_bwd
    @contiguous
    def backward(ctx, do, d_ht=None):
        q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
        scale = q.shape[-1] ** -0.5
        BT = ctx.BT
        w, u = fwd_recompute_w_u(k, v, beta, A, BT)
        # checkpoint_level=1: recompute h and v_new.
        if h is None:
            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
        dv = fwd_prepare_dv(q, k, do, BT)
        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)
        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, dv, BT)
        dk.add_(dk2)
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None


def chunk_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    BT: int,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
):
    assert q.dtype == k.dtype == v.dtype
    if initial_state is not None:
        initial_state = initial_state.detach()
    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
    return o, final_state
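

if __name__ == "__main__":
    # Hedged addition (not in the original file): a minimal forward smoke test in the spirit of
    # the __main__ blocks in the sibling files; shapes, device and chunk size are assumptions.
    import torch.nn.functional as F
    B, H, T, D = 2, 4, 128, 64
    q = torch.randn(B, H, T, D, device='cuda')
    k = F.normalize(torch.randn(B, H, T, D, device='cuda'), p=2, dim=-1)
    v = torch.randn(B, H, T, D, device='cuda')
    beta = torch.rand(B, H, T, device='cuda').sigmoid()
    o, _ = chunk_delta_rule(q, k, v, beta, BT=64)
    print(o.shape)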
419
finetune/lora/v6/fla/ops/delta_rule/chunk_fuse.py
vendored
Normal file
419
finetune/lora/v6/fla/ops/delta_rule/chunk_fuse.py
vendored
Normal file
@@ -0,0 +1,419 @@
# -*- coding: utf-8 -*-

from typing import Tuple

import torch
import triton
import triton.language as tl
from packaging import version
from torch.cuda.amp import custom_bwd, custom_fwd

from fla.ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
from fla.utils import contiguous


# on-the-fly computation without materializing hidden states into HBM
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8)
    ],
    key=["BT", "BK"],
)
@triton.jit
def fused_chunk_delta_rule_fwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
    q,  # query [B, H, L, D_head_K]
    k,  # key [B, H, L, D_head_K]
    v,  # value [B, H, L, D_head_V]
    v_new,
    d,  # decay [B, H, L, D_head_K]
    o,  # output [B, H, L, D_head_V]
    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
    s_qk_h,  # stride size: L * D_head_K
    s_qk_t,  # stride size: D_head_K
    s_qk_d,  # stride size: 1
    s_vo_h,  # stride size: L * D_head_V
    s_vo_t,  # stride size: D_head_V
    s_vo_d,  # stride size: 1
    B,  # batch size
    H,  # n_heads
    T,  # seq_len
    scale,  # D_head_K ** -0.5
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    DK: tl.constexpr,  # D_head_K
    DV: tl.constexpr,  # D_head_V
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    CHECK: tl.constexpr
):
    # indices
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    o_i = tl.arange(0, BT)

    # [BT, BT]
    m_s = o_i[:, None] >= o_i[None, :]
    # [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    # make block pointers
    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
    p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
    p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))

    if USE_INITIAL_STATE:
        p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)

    for i in range(0, tl.cdiv(T, BT)):
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_d = tl.load(p_d, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_k.dtype)

        # [BT, BT]
        b_s = tl.dot(b_q, b_k, allow_tf32=False)
        b_s = tl.where(m_s, b_s, 0)
        # [BT, BV]
        b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)
        b_v = b_v - b_v_prime
        tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))

        b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False)
        if CHECK and i == 0:
            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
            b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
        else:
            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
            b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
        p_q = tl.advance(p_q, (BT, 0))
        p_k = tl.advance(p_k, (0, BT))
        p_v = tl.advance(p_v, (BT, 0))
        p_v_new = tl.advance(p_v_new, (BT, 0))
        p_o = tl.advance(p_o, (BT, 0))
        p_d = tl.advance(p_d, (BT, 0))

    if STORE_FINAL_STATE:
        p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))


# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def fused_chunk_delta_rule_bwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
    # NV: number of split in the V dimension. NK: number of split in the K dimension
    q,  # query [B, H, L, D_head_K]
    k,  # key [B, H, L, D_head_K]
    v,  # value [B, H, L, D_head_V]
    d,  # decay [B, H, L, D_head_K]
    do,  # gradient of output [B, H, L, D_head_V]
    dq,  # gradient of query [NV, B, H, L, D_head_K]
    dk,  # gradient of key [NV, B, H, L, D_head_K]
    dv,  # gradient of value [NK, B, H, L, D_head_V]
    dd,  # gradient of decay [NV, B, H, L, D_head_K]
    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
    s_qk_h,  # stride size: L * D_head_K
    s_qk_t,  # stride size: D_head_K
    s_qk_d,  # stride size: 1
    s_vo_h,  # stride size: L * D_head_V
    s_vo_t,  # stride size: D_head_V
    s_vo_d,  # stride size: 1
    B,  # batch_size
    H,  # n_heads
    T,  # seq_len
    scale,  # D_head_K ** -0.5
    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    DK: tl.constexpr,  # D_head_K
    DV: tl.constexpr,  # D_head_V
    USE_INITIAL_STATE: tl.constexpr,
    CHECK: tl.constexpr
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_i = tl.arange(0, BT)

    # first reverse
    # [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    m_s = o_i[:, None] <= o_i[None, :]
    for i in range(1, tl.cdiv(T, BT) + 1):
        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
        p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))

        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
        p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
        # [DK, BT]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, DK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, DV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))

        # [BT, BT]
        b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
        b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
        # [BT, BT]
        b_s = tl.dot(b_k, b_q, allow_tf32=False)
        b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
        # [BT, DK]
        b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
        # [BT, DV]
        b_dv = tl.dot(b_s, b_do, allow_tf32=False)
        b_d = tl.load(p_d, boundary_check=(0, 1))
        if CHECK and i == 1:
            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
            b_dh += tl.dot(b_q, b_do, allow_tf32=False)
            b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
        else:
            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
            b_dh += tl.dot(b_q, b_do, allow_tf32=False)
            b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)

        tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    # sync threads
    b_h = None
    tl.debug_barrier()
    m_s = o_i[:, None] >= o_i[None, :]
    # [BV, BK]
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
    NT = tl.cdiv(T, BT)
    for i in range(0, NT):
        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
        p_dq = tl.make_block_ptr(dq + (i_bh + i_v * B * H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))

        # [BT, DK]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [DV, BT]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BT, DV]
        b_do = tl.load(p_do, boundary_check=(0, 1))

        # [BT, BT]
        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
        b_ds = tl.where(m_s, b_ds, 0)
        # [BT, DK]
        b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
        # [DV, DK]
        if CHECK and i == 0:
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
            b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
        else:
            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
            b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
        b_dq *= scale
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

        if i < (NT - 1):
            p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0))
            b_dv = tl.load(p_dv, boundary_check=(0, 1))
            b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
            p_dd = tl.make_block_ptr(dd + (i_bh + i_v * B * H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
                                     ((i + 1) * BT, i_k * BK), (BT, BK), (1, 0))
            tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1))


def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):
    batch_size, n_heads, seq_len, d_head_qk = q.shape
    d_head_v = v.shape[-1]
    scale = d_head_qk ** -0.5
    # ctx.BT = BT
    BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
    NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
    assert NK == 1, 'NK should be 1'
    o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)
    if output_final_state:
        final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
    else:
        final_state = None
    CHECK = True
    # if version.parse(triton.__version__) < version.parse('2.2.0'):
    #     import warnings
    #     warnings.warn(
    #         "Triton<2.2.0 detected for running this kernel, "
    #         "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
    #         "that lead to significant precision loss. "
    #         "We've added some initial condition checks to resolve this, sadly at the sacrifice of speed. "
    #         "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
    #     )
    #     CHECK = True
    grid = (NV, NK, batch_size * n_heads)
    v_new = torch.empty_like(v)
    fused_chunk_delta_rule_fwd_kernel[grid](
        q, k, v, v_new, d, o, initial_state, final_state,
        q.stride(1), q.stride(2), q.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        batch_size, n_heads, seq_len, scale,
        BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
        USE_INITIAL_STATE=initial_state is not None,
        STORE_FINAL_STATE=output_final_state,
        CHECK=CHECK,
    )
    return o, v_new, CHECK, final_state


def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):
    batch_size, n_heads, seq_len, d_head_qk = q.shape
    d_head_v = v.shape[-1]
    scale = d_head_qk ** -0.5
    BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
    NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
    assert NK == 1
    dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
    dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
    dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
    dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
    grid = (NV, NK, batch_size * n_heads)
    fused_chunk_delta_rule_bwd_kernel[grid](
        q, k, v, d, do, dq, dk, dv, dd, initial_state,
        q.stride(1), q.stride(2), q.stride(3),
        v.stride(1), v.stride(2), v.stride(3),
        batch_size, n_heads, seq_len, scale,
        BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
        USE_INITIAL_STATE=initial_state is not None,
        CHECK=CHECK,
        # num_warps=num_warps,
        # num_stages=num_stages
    )
    dq = dq.sum(0)
    dk = dk.sum(0)
    dv = dv.sum(0)
    dd = dd.sum(0)
    dd[:, :, 0:BT] = 0
    return dq, dk, dv, dd


class FusedChunkDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    @custom_fwd
    def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0):
        # lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory.
        assert checkpoint_level in [0, 1]
        k_origin = k
        # k = _l2_norm_fwd(k_origin)
        k = k
        d, v_new = fwd_prepare_wy_repr(k, v, beta, BT)
        o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state)
        if checkpoint_level == 1:
            d, v_new = None, None
        ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state)
        ctx.CHECK = CHECK
        ctx.chunk_size = BT
        return o.to(q.dtype), final_state

    @staticmethod
    @custom_bwd
    @contiguous
    def backward(ctx, do, d_final_state=None):
        q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors
        chunk_size = ctx.chunk_size
        k = k_origin
        # k = _l2_norm_fwd(k_origin)
        if d is None:
            d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
        dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state)
        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size)
        dk.add_(dk2)
        # dk = _l2_norm_bwd(k_origin, dk)
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(d.dtype), None, None, None


def fused_chunk_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    BT: int,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if initial_state is not None:
        initial_state = initial_state.detach()
    o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
    return o, final_state


def delta_rule_recurrence(q, k, v, beta):
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(b, h, d_k, d_v).to(v)
    q = q * (d_k ** -0.5)
    k = torch.nn.functional.normalize(k, p=2, dim=-1)
    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        _v = v[:, :, i].clone()
        beta_i = beta[:, :, i]
        _v = _v - (S.clone() * _k[..., None]).sum(-2)
        _v = _v * beta_i[..., None]
        S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    return o


if __name__ == "__main__":
    import torch.nn.functional as F
    seq_len = 128
    b = 2
    h = 4
    q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
    k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
    v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1)
    beta = torch.rand(b, h, seq_len).sigmoid()
    q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta))
    do = torch.rand_like(v)
    o2 = delta_rule_recurrence(q, k, v.clone(), beta)
    o2.backward(do, retain_graph=True)
    q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad
    q.grad = k.grad = v.grad = beta.grad = None
    o, _ = fused_chunk_delta_rule(q, k, v, beta, 32)
    o.backward(do, retain_graph=True)
    q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad
    q.grad = k.grad = v.grad = beta.grad = None
    print((o - o2).abs().max())
    print((q_grad - q_grad2).abs().max())
    print((k_grad - k_grad2).abs().max())
    print((v_grad - v_grad2).abs().max())
    print((beta_grad - beta_grad2).abs().max())
92
finetune/lora/v6/fla/ops/delta_rule/naive.py
vendored
Normal file
92
finetune/lora/v6/fla/ops/delta_rule/naive.py
vendored
Normal file
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-

import torch
from einops import rearrange


def delta_rule_recurrence(q, k, v, beta):
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    o = torch.zeros_like(v)
    S = torch.zeros(b, h, d_k, d_v).to(v)
    q = q * (d_k ** -0.5)
    for i in range(l):
        _k = k[:, :, i]
        _q = q[:, :, i]
        _v = v[:, :, i].clone()
        beta_i = beta[:, :, i]
        _v = _v - (S.clone() * _k[..., None]).sum(-2)
        _v = _v * beta_i[..., None]
        S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
    return o


def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
    b, h, l, d_k = q.shape
    d_v = v.shape[-1]
    q = q * (d_k ** -0.5)
    v = v * beta[..., None]
    k_beta = k * beta[..., None]

    assert l % chunk_size == 0

    # note that diagonal is masked.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
    q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
    attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)

    for i in range(1, chunk_size):
        attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)

    attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
    # u
    k_cumsum = attn @ v
    # w
    k_cumdecay = attn @ k_beta

    v = k_cumsum
    S = k.new_zeros(b, h, d_k, d_v)
    o = torch.zeros_like(v)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
    for i in range(0, l // chunk_size):
        q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
        v_prime = k_cumdecay[:, :, i] @ S
        v_new = v_i - v_prime
        o_inter = q_i @ S
        o[:, :, i] = o_inter + attn @ v_new
        # chunk state update
        S = S + k_i.transpose(-1, -2) @ v_new

    return rearrange(o, 'b h n c d -> b h (n c) d')


if __name__ == '__main__':
    B = 2
    H = 4
    L = 256
    DK = 128
    DV = 128
    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
    k = (torch.randn(B, H, L, DK)).cuda()
    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)

    o = delta_rule_recurrence(q, k, v, beta)
    do = torch.randn(B, H, L, DV).cuda()
    o.backward(do, retain_graph=True)
    q_grad, q.grad = q.grad, None
    k_grad, k.grad = k.grad, None
    v_grad, v.grad = v.grad, None
    beta_grad, beta.grad = beta.grad, None

    o2 = delta_rule_chunkwise(q, k, v, beta)
    o2.backward(do)
    assert torch.allclose(o, o2, atol=1e-4), breakpoint()
    assert torch.allclose(q.grad, q_grad, atol=1e-4), breakpoint()
    assert torch.allclose(k.grad, k_grad, atol=1e-4), breakpoint()
    assert torch.allclose(v.grad, v_grad, atol=1e-4), breakpoint()
    assert torch.allclose(beta.grad, beta_grad, atol=1e-4), breakpoint()
    print("All passed!")
312
finetune/lora/v6/fla/ops/delta_rule/recurrent_fuse.py
vendored
Normal file
312
finetune/lora/v6/fla/ops/delta_rule/recurrent_fuse.py
vendored
Normal file
@@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang

from typing import Tuple

import torch
import triton
import triton.language as tl

from fla.utils import contiguous

# on-the-fly computation without materializing hidden states into HBM


@triton.jit
def fused_recurrent_fwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
    q,  # query [B, H, L, D_head_K]
    k,  # key [B, H, L, D_head_K]
    v,  # value [B, H, L, D_head_V]
    beta,  # beta [B, H, L]
    o,  # output [B, H, L, D_head_V]
    initial_state,
    final_state,  # final hidden state [B, H, D_head_K, D_head_V]

    s_qk_h,  # stride size: L * D_head_K
    s_qk_t,  # stride size: D_head_K
    s_qk_d,  # stride size: 1

    s_vo_h,  # stride size: L * D_head_V
    s_vo_t,  # stride size: D_head_V
    s_vo_d,  # stride size: 1

    B,  # batch size
    H,  # n_heads
    T,  # seq_len
    scale,  # D_head_K ** -0.5
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    DK: tl.constexpr,  # D_head_K
    DV: tl.constexpr,  # D_head_V
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
):

    # indices
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
    p_beta = beta + i_bh * T
    p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)

    mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
    mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
    mask_kv = mask_bk[None, :] & mask_bv[:, None]

    h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_init_s = initial_state + i_bh * DK * DV + \
            (i_k * BK + tl.arange(0, BK)[None, :]) * \
            DV + (i_v * BV + tl.arange(0, BV)[:, None])
        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)

    for _ in range(0, T):
        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
        _v_minus = tl.sum(h * _k[None, :], axis=1)
        _v -= _v_minus
        _beta = tl.load(p_beta).to(tl.float32)
        # in-place overwrite
        tl.store(p_v, _v.to(p_v.dtype.element_ty), mask=mask_bv)
        _v *= _beta
        h += _k[None, :] * _v[:, None]
        _o = h * _q[None, :]
        _o = tl.sum(_o, axis=1)
        tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)

        p_q += DK
        p_k += DK
        p_o += DV
        p_v += DV
        p_beta += 1

    if STORE_FINAL_STATE:
        p_final_s = final_state + i_bh * DK * DV + \
            (i_k * BK + tl.arange(0, BK)[None, :]) * \
            DV + (i_v * BV + tl.arange(0, BV)[:, None])
        tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)


# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_bwd_kernel(
    # B: batch_size, H: n_heads, T: seq_len, D: d_head
    # NV: number of split in the V dimension. NK: number of split in the K dimension
    q,  # query [B, H, L, D_head_K]
    k,  # key [B, H, L, D_head_K]
    v,  # value [B, H, L, D_head_V]
    beta,  # beta [B, H, L]

    do,  # gradient of output [B, H, L, D_head_V]
    dq,  # gradient of query [NV, B, H, L, D_head_K]
    dk,  # gradient of key [NV, B, H, L, D_head_K]
    dv,  # gradient of value [NK, B, H, L, D_head_V]
    dbeta,  # gradient of beta [B, H, L]

    # initial hidden state initialization [B, H, D_head_K, D_head_V]
    initial_state,

    s_qk_h,  # stride size: L * D_head_K
    s_qk_t,  # stride size: D_head_K
    s_qk_d,  # stride size: 1

    s_vo_h,  # stride size: L * D_head_V
    s_vo_t,  # stride size: D_head_V
    s_vo_d,  # stride size: 1

    B,  # batch_size
    H,  # n_heads
    T,  # seq_len
    scale,  # D_head_K ** -0.5
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    DK: tl.constexpr,  # D_head_K
    DV: tl.constexpr,  # D_head_V
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    mask_bk = i_k * BK + tl.arange(0, BK) < DK
    mask_bv = i_v * BV + tl.arange(0, BV) < DV

    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
    p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
    p_beta = beta + i_bh * T + T - 1
    p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1

    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
        BK + tl.arange(0, BK) + (T - 1) * DK
    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
        BV + tl.arange(0, BV) + (T - 1) * DV
    d_h = tl.zeros([BK, BV], dtype=tl.float32)

    for _ in range(T):
        _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
        _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        _beta = tl.load(p_beta).to(tl.float32)
        d_h += _q[:, None] * _do[None, :]
        d_k = tl.sum(d_h * _v[None, :] * _beta, axis=1)
        d_v = tl.sum(d_h * _k[:, None], axis=0)

        d_beta = tl.sum(d_v * _v)
        d_v = d_v * _beta

        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
        tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))

        d_h -= _k[:, None] * d_v[None, :]

        p_do -= DV
        p_q -= DK
        p_k -= DK
        p_v -= DV
        p_dk -= DK
        p_dv -= DV
        p_dbeta -= 1
        p_beta -= 1

    tl.debug_barrier()

    h = tl.zeros([BK, BV], dtype=tl.float32)

    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
    p_beta = beta + i_bh * T
    p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
    p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
    p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + DV
    p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + DK

    if USE_INITIAL_STATE:
        mask_kv = mask_bk[:, None] & mask_bv[None, :]
        p_init_s = initial_state + i_bh * DK * DV + \
            (i_k * BK + tl.arange(0, BK)[:, None]) * \
            DV + (i_v * BV + tl.arange(0, BV)[None, :])
        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)

    for i in range(0, T):
        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
        _beta = tl.load(p_beta).to(tl.float32)
        _v *= _beta

        h += _k[:, None] * _v[None, :]
        _d_q = h * _do[None, :]
        d_q = tl.sum(_d_q, axis=1) * scale
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)

        if i < T - 1:
            d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32)
            d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32)
            d_k -= tl.sum(d_v[None, :] * h, axis=1)
            tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)

        p_k += DK
        p_do += DV
        p_v += DV
        p_dk += DK
        p_dv += DV
        p_dq += DK
        p_beta += 1


class FusedRecurrentFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    def forward(ctx, q, k, v, beta, initial_state=None, output_final_state=False):
        batch_size, n_heads, seq_len, d_head_qk = q.shape
        d_head_v = v.shape[-1]

        scale = d_head_qk ** -0.5
        BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 8)
        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
        num_stages = 1
        num_warps = 1
        assert NK == 1, "NK > 1 is not supported yet"
        o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)

        if output_final_state:
            final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
        else:
            final_state = None

        grid = (NV, NK, batch_size * n_heads)
        fused_recurrent_fwd_kernel[grid](
            q, k, v, beta, o, initial_state, final_state,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            batch_size, n_heads, seq_len, scale,
            DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=final_state is not None
        )
        o = o.sum(0)
        ctx.save_for_backward(q, k, v, beta, initial_state)
        return o, final_state

    @staticmethod
    @contiguous
    def backward(ctx, do, d_final_state=None):
        q, k, v, beta, initial_state = ctx.saved_tensors
        batch_size, n_heads, seq_len, d_head_qk = q.shape
        d_head_v = v.shape[-1]
        scale = d_head_qk ** -0.5
        BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
        assert NK == 1, "NK > 1 is not supported yet"
        num_stages = 1
        num_warps = 2

        dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
        dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
        dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
        grid = (NV, NK, batch_size * n_heads)
        dbeta = q.new_empty(NV, batch_size, n_heads, seq_len)

        fused_recurrent_bwd_kernel[grid](
            q, k, v, beta, do, dq, dk, dv, dbeta, initial_state,
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            batch_size, n_heads, seq_len, scale,
            DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages,
            USE_INITIAL_STATE=initial_state is not None
        )
        dq = dq.sum(0)
        dk = dk.sum(0)
        dv = dv.sum(0)
        dbeta = dbeta.sum(0)
        return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None


def fused_recurrent_linear_attn_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    normalize: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    if initial_state is not None:
        initial_state = initial_state.detach()
    if beta is None:
        beta = torch.ones_like(q[..., 0])
    o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, initial_state, output_final_state)
    return o, final_state
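

if __name__ == "__main__":
    # Hedged addition (not in the original file): a minimal forward smoke test; shapes and
    # device are illustrative assumptions. Note the forward kernel overwrites `v` in place
    # with the updated values, so a clone is passed.
    B, H, T, D = 2, 4, 64, 32
    q = torch.randn(B, H, T, D, device='cuda')
    k = torch.nn.functional.normalize(torch.randn(B, H, T, D, device='cuda'), p=2, dim=-1)
    v = torch.randn(B, H, T, D, device='cuda')
    beta = torch.rand(B, H, T, device='cuda').sigmoid()
    o, final_state = fused_recurrent_linear_attn_delta_rule(q, k, v.clone(), beta, output_final_state=True)
    print(o.shape, final_state.shape)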
297
finetune/lora/v6/fla/ops/delta_rule/utils.py
vendored
Normal file
297
finetune/lora/v6/fla/ops/delta_rule/utils.py
vendored
Normal file
@@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-

import torch
import triton
import triton.language as tl
from einops import rearrange
from torch.cuda.amp import custom_bwd, custom_fwd

from fla.utils import contiguous
from fla.ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2


# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
# o: cumprod
# o2: cumprodsum
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def fwd_prepare_wy_repr_kernel(
    k,
    v,
    beta,
    o,
    o2,
    T,
    K,
    V,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)

    p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
    p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
    p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
    mask_bt = (tl.arange(0, BT) + i_t * BT) < T
    mask_bk = tl.arange(0, BK) < K
    mask_bv = tl.arange(0, BV) < V
    mask_bk = mask_bk[None, :] & mask_bt[:, None]
    mask_bv = mask_bv[None, :] & mask_bt[:, None]
    # [BT, BK]
    b_k = tl.load(p_k, mask=mask_bk, other=0)
    # [BT,]
    b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)
    # [BT, BV]
    b_v = tl.load(p_v, mask=mask_bv, other=0)
    b_v = (b_v * b_beta[:, None]).to(b_v.dtype)
    # [BT, BK]
    b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
    # [BT, BT]
    b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
    b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)

    for i in range(BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
    b_A = b_A.to(b_k.dtype)
    b_w = tl.dot(b_A, b_kb, allow_tf32=False)
    b_u = tl.dot(b_A, b_v, allow_tf32=False)

    p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
    tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)
    p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
    tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["BT", "BK", "BV"],
)
@triton.jit
def bwd_prepare_wy_repr_kernel(
    k, v, beta,
    o, o2, do, do2,
    dk, dv, dbeta,
    NT, K, V, T,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
|
||||
p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
|
||||
p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
|
||||
|
||||
p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
|
||||
mask_bt = (tl.arange(0, BT) + i_t * BT) < T
|
||||
mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]
|
||||
mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]
|
||||
b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)
|
||||
|
||||
b_beta = b_beta.to(tl.float32)
|
||||
A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]
|
||||
A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)
|
||||
b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)
|
||||
b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)
|
||||
dA = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
for i in range(BT-1, -1, -1):
|
||||
mask = tl.arange(0, BT) == i
|
||||
attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)
|
||||
do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)
|
||||
dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)
|
||||
b_do = b_do - attn[:, None] * do_[None, :]
|
||||
b_dv = b_dv - attn[:, None] * dv_[None, :]
|
||||
tl.debug_barrier()
|
||||
p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
|
||||
b_v = tl.load(p_v, mask=mask_bv)
|
||||
b_dk += b_do * b_beta[:, None]
|
||||
b_dbeta = tl.sum(b_do * b_k, axis=1)
|
||||
b_dbeta += tl.sum(b_dv * b_v, axis=1)
|
||||
b_v = None
|
||||
|
||||
p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
|
||||
p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
|
||||
b_o = tl.load(p_o, mask=mask_bk)
|
||||
b_o2 = tl.load(p_o2, mask=mask_bv)
|
||||
|
||||
dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)
|
||||
dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),
|
||||
allow_tf32=False)
|
||||
dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)
|
||||
b_dv *= b_beta[:, None]
|
||||
p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
|
||||
|
||||
b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)
|
||||
dA = dA * b_beta[:, None]
|
||||
b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)
|
||||
b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)
|
||||
p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
|
||||
p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)
|
||||
tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)
|
||||
|
||||
|
||||
def fwd_prepare_wy_repr(k, v, beta, chunk_size):
|
||||
B, H, T, K, V = *k.shape, v.shape[-1]
|
||||
v_new = torch.empty_like(v)
|
||||
o_cumdecay = torch.empty_like(k)
|
||||
BT = chunk_size
|
||||
NT = triton.cdiv(T, BT)
|
||||
BK = triton.next_power_of_2(K)
|
||||
BV = triton.next_power_of_2(V)
|
||||
fwd_prepare_wy_repr_kernel[(NT, B*H)](
|
||||
k, v, beta, o_cumdecay, v_new,
|
||||
T, K, V, BT, BK, BV
|
||||
)
|
||||
return o_cumdecay, v_new
|
||||
|
||||
|
||||
def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):
|
||||
b, h, l, d_k = do.shape
|
||||
d_v = v.shape[-1]
|
||||
BK = triton.next_power_of_2(d_k)
|
||||
BV = triton.next_power_of_2(d_v)
|
||||
c = chunk_size
|
||||
BK = d_k
|
||||
NT = triton.cdiv(l, c)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
dbeta = torch.zeros_like(beta)
|
||||
bwd_prepare_wy_repr_kernel[(NT, b*h)](
|
||||
k, v, beta,
|
||||
o_cumdecay, v_new, do, do2,
|
||||
dk, dv, dbeta,
|
||||
NT, d_k, d_v, l, chunk_size, BK, BV
|
||||
)
|
||||
return dk, dv, dbeta
|
||||
|
||||
class WYRepresentationPrepration(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, k, v, beta, chunk_size):
|
||||
o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
|
||||
ctx.chunk_size = chunk_size
|
||||
ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new)
|
||||
return o_cumdecay, v_new
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_bwd
|
||||
def backward(ctx, do, do2):
|
||||
k, v, beta, o_cumdecay, v_new = ctx.saved_tensors
|
||||
dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size)
|
||||
return dk, dv, dbeta, None
|
||||
|
||||
prepare_wy_repr = WYRepresentationPrepration.apply
|
||||
|
||||
|
||||
def naive(k, v, beta, chunk_size):
|
||||
l_org = k.shape[2]
|
||||
l_new = triton.next_power_of_2(l_org)
|
||||
# pad k, v, beta
|
||||
k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
|
||||
v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
|
||||
beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
|
||||
|
||||
k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
|
||||
# k = torch.nn.functional.normalize(k, dim=-1, p=2)
|
||||
beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
|
||||
k_beta = k * beta[..., None]
|
||||
v = v * beta[..., None]
|
||||
attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
|
||||
attn = attn * beta[..., None]
|
||||
x = attn @ v
|
||||
|
||||
o = torch.zeros_like(k)
|
||||
o2 = torch.zeros_like(v)
|
||||
|
||||
o[..., 0, :] = k_beta[..., 0, :].clone()
|
||||
o2[..., 0, :] = x[..., 0, :].clone()
|
||||
for i in range(1, chunk_size):
|
||||
o_i = (o[..., :i, :]).clone()
|
||||
o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
|
||||
o2_i = (o2[..., :i, :]).clone()
|
||||
o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
|
||||
return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
seq_len = 2048
|
||||
b = 4
|
||||
h = 8
|
||||
k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2)
|
||||
v = torch.randn(b, h, seq_len, 256)
|
||||
beta = torch.rand(b, h, seq_len).sigmoid()
|
||||
require_grad = True
|
||||
k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
|
||||
do = torch.rand_like(k)
|
||||
do2 = torch.rand_like(v)
|
||||
|
||||
print("Start warmup.")
|
||||
o1, o2 = prepare_wy_repr(k, v, beta, 32)
|
||||
# (o1 * do + o2 * do2).sum().backward()
|
||||
o3, o4 = prepare_wy_repr2(k, v, beta, 32)
|
||||
# (o1 * do + o2 * do2).sum().backward()
|
||||
print((o1 - o3).abs().max())
|
||||
print((o2 - o4).abs().max())
|
||||
|
||||
|
||||
for i in range(30):
|
||||
o1, o2 = prepare_wy_repr(k, v, beta, 32)
|
||||
(o1 * do + o2 * do2).sum().backward()
|
||||
o1, o2 = prepare_wy_repr2(k, v, beta, 32)
|
||||
(o1 * do + o2 * do2).sum().backward()
|
||||
|
||||
print("Done warmup.")
|
||||
|
||||
import time
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
|
||||
for i in range(200):
|
||||
o1, o2 = prepare_wy_repr(k, v, beta, 64)
|
||||
(o1 * do + o2 * do2).sum().backward()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
print(time.time() - start)
|
||||
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
|
||||
for i in range(200):
|
||||
o1, o2 = prepare_wy_repr2(k, v, beta, 64)
|
||||
(o1 * do + o2 * do2).sum().backward()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
print(time.time() - start)
|
||||
|
||||
|
||||
|
||||
401
finetune/lora/v6/fla/ops/delta_rule/wy_fast.py
vendored
Normal file
@@ -0,0 +1,401 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
|
||||
# o: cumprod
|
||||
# o2: cumprodsum
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["BT", "BK", "BV"],
|
||||
)
|
||||
@triton.jit
|
||||
def fwd_prepare_wy_repr_kernel(
|
||||
k,
|
||||
v,
|
||||
beta,
|
||||
w,
|
||||
u,
|
||||
A,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
T,
|
||||
K,
|
||||
V,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
|
||||
b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
|
||||
|
||||
b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
|
||||
|
||||
for i in range(1, BT):
|
||||
mask = tl.arange(0, BT) == i
|
||||
b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
|
||||
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
|
||||
b_A = tl.where(mask[:, None], b_a, b_A)
|
||||
|
||||
b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
|
||||
|
||||
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
||||
b_A = b_A.to(k.dtype.element_ty)
|
||||
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
|
||||
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
|
||||
p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
|
||||
b_w = tl.dot(b_A, b_kb, allow_tf32=False)
|
||||
p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["BT", "BK", "BV"],
|
||||
)
|
||||
@triton.jit
|
||||
def fwd_recompute_w_u_kernel(
|
||||
k,
|
||||
v,
|
||||
beta,
|
||||
w,
|
||||
u,
|
||||
A,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
T,
|
||||
K,
|
||||
V,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
|
||||
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
|
||||
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
|
||||
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
|
||||
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
|
||||
p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
|
||||
b_w = tl.dot(b_A, b_kb, allow_tf32=False)
|
||||
p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
triton.Config({}, num_warps=16),
|
||||
triton.Config({}, num_warps=32),
|
||||
],
|
||||
key=["BT", "BK", "BV"],
|
||||
)
|
||||
@triton.jit
|
||||
def bwd_prepare_wy_repr_kernel(
|
||||
k, v, beta, A,
|
||||
dw, du,
|
||||
dk, dv, dbeta,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
T,
|
||||
K,
|
||||
V,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
|
||||
|
||||
b_dbeta = tl.zeros([BT], dtype=tl.float32)
|
||||
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
|
||||
b_du = tl.load(p_du, boundary_check=(0, 1))
|
||||
b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
|
||||
b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)
|
||||
b_dv = b_dv_beta * b_beta[:, None]
|
||||
b_dbeta += tl.sum(b_dv_beta * b_v, 1)
|
||||
# store
|
||||
p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
tl.debug_barrier()
|
||||
b_A2 = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
|
||||
b_dw = tl.load(p_dw, boundary_check=(0, 1))
|
||||
b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
|
||||
b_A2 += tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)
|
||||
b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
|
||||
b_dk = b_dk_beta * b_beta[:, None]
|
||||
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
|
||||
# store
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
b_A -= (tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :])
|
||||
b_A2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_A2, 0)
|
||||
b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
|
||||
tl.debug_barrier()
|
||||
|
||||
for i in range(BT-1, 0, -1):
|
||||
mask = tl.arange(0, BT) == i
|
||||
b_da = tl.sum(tl.where(mask[:, None], b_dA, 0), 0)
|
||||
b_a = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
|
||||
b_da2 = b_da + tl.sum(b_da[None, :] * b_A, 1)
|
||||
b_dA = tl.where(mask[:, None], b_da2, b_dA)
|
||||
b_dA += b_da[None, :] * b_a[:, None]
|
||||
|
||||
b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
|
||||
tl.debug_barrier()
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_dk = tl.load(p_dk, boundary_check=(0, 1))
|
||||
b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
|
||||
|
||||
b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
|
||||
b_dbeta += tl.sum(b_dk_beta * b_k, 1)
|
||||
b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
|
||||
b_dk += b_dk_beta * b_beta[:, None]
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
|
||||
tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty),boundary_check=(0,))
|
||||
|
||||
|
||||
def fwd_prepare_wy_repr(k, v, beta, BT):
|
||||
B, H, T, K, V = *k.shape, v.shape[-1]
|
||||
u = torch.empty_like(v)
|
||||
w = torch.empty_like(k)
|
||||
NT = triton.cdiv(T, BT)
|
||||
BK = min(triton.next_power_of_2(K), 64)
|
||||
BV = min(triton.next_power_of_2(V), 64)
|
||||
A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype)
|
||||
fwd_prepare_wy_repr_kernel[(NT, B*H)](
|
||||
k, v, beta, w, u, A,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
T, K, V, BT, BK, BV
|
||||
)
|
||||
return w, u, A
|
||||
|
||||
|
||||
|
||||
def fwd_recompute_w_u(k, v, beta, A, BT):
|
||||
B, H, T, K, V = *k.shape, v.shape[-1]
|
||||
u = torch.empty_like(v)
|
||||
w = torch.empty_like(k)
|
||||
NT = triton.cdiv(T, BT)
|
||||
BK = min(triton.next_power_of_2(K), 64)
|
||||
BV = min(triton.next_power_of_2(V), 64)
|
||||
fwd_recompute_w_u_kernel[(NT, B*H)](
|
||||
k, v, beta, w, u, A,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
T, K, V, BT, BK, BV
|
||||
)
|
||||
return w, u
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT):
|
||||
B, H, T, K, V = *k.shape, v.shape[-1]
|
||||
|
||||
NT = triton.cdiv(T, BT)
|
||||
BK = min(triton.next_power_of_2(K), 64)
|
||||
BV = min(triton.next_power_of_2(V), 64)
|
||||
NT = triton.cdiv(T, BT)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v).contiguous()
|
||||
dbeta = torch.zeros_like(beta)
|
||||
|
||||
bwd_prepare_wy_repr_kernel[(NT, B*H)](
|
||||
k, v, beta, A,
|
||||
dw, du,
|
||||
dk, dv, dbeta,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
T, K, V, BT, BK, BV
|
||||
)
|
||||
return dk, dv, dbeta
|
||||
|
||||
|
||||
class WYRepresentationPrepration(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, k, v, beta, chunk_size):
|
||||
ctx.BT = chunk_size
|
||||
w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT)
|
||||
ctx.save_for_backward(k, v, beta, A)
|
||||
return w, u
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_bwd
|
||||
def backward(ctx, dw, du):
|
||||
k, v, beta, A = ctx.saved_tensors
|
||||
BT = ctx.BT
|
||||
dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT)
|
||||
return dk, dv, dbeta, None
|
||||
|
||||
|
||||
|
||||
|
||||
prepare_wy_repr = WYRepresentationPrepration.apply
|
||||
|
||||
def naive(k, v, beta, chunk_size):
|
||||
l_org = k.shape[2]
|
||||
l_new = triton.next_power_of_2(l_org)
|
||||
# pad k, v, beta
|
||||
k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
|
||||
v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
|
||||
beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
|
||||
|
||||
k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
|
||||
# k = torch.nn.functional.normalize(k, dim=-1, p=2)
|
||||
beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
|
||||
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
|
||||
k_beta = k * beta[..., None]
|
||||
v = v * beta[..., None]
|
||||
attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
|
||||
attn = attn * beta[..., None]
|
||||
x = attn @ v
|
||||
|
||||
o = torch.zeros_like(k)
|
||||
o2 = torch.zeros_like(v)
|
||||
|
||||
o[..., 0, :] = k_beta[..., 0, :].clone()
|
||||
o2[..., 0, :] = x[..., 0, :].clone()
|
||||
for i in range(1, chunk_size):
|
||||
o_i = (o[..., :i, :]).clone()
|
||||
o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
|
||||
o2_i = (o2[..., :i, :]).clone()
|
||||
o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
|
||||
return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_default_dtype(torch.float32)
|
||||
seq_len = 1024
|
||||
b = 4
|
||||
h = 4
|
||||
k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)
|
||||
v = torch.randn(b, h, seq_len, 128)
|
||||
beta = torch.rand(b, h, seq_len).sigmoid()
|
||||
# beta = torch.ones(b, h, seq_len)
|
||||
require_grad = True
|
||||
|
||||
k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
|
||||
do = torch.rand_like(k)
|
||||
do2 = torch.rand_like(v)
|
||||
|
||||
o1, o2 = naive(k.clone(), v.clone(), beta.clone(), 64)
|
||||
if require_grad:
|
||||
o1.backward(do, retain_graph=True)
|
||||
o2.backward(do2, retain_graph=True)
|
||||
|
||||
k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad
|
||||
k.grad = v.grad = beta.grad = None
|
||||
|
||||
o3, o4 = prepare_wy_repr(k.clone(), v.clone(), beta.clone(), 64)
|
||||
print((o1-o3).abs().max())
|
||||
print((o2-o4).abs().max())
|
||||
|
||||
if require_grad:
|
||||
o3.backward(do, retain_graph=True)
|
||||
o4.backward(do2, retain_graph=True)
|
||||
k_grad, v_grad, beta_grad = k.grad, v.grad, beta.grad
|
||||
print((k_grad2-k_grad).abs().max())
|
||||
print((v_grad2-v_grad).abs().max())
|
||||
print((beta_grad2-beta_grad).abs().max())
|
||||
breakpoint()
|
||||
|
||||
11
finetune/lora/v6/fla/ops/gla/__init__.py
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .chunk import chunk_gla
|
||||
from .chunk_fuse import fused_chunk_gla
|
||||
from .recurrent_fuse import fused_recurrent_gla
|
||||
|
||||
__all__ = [
|
||||
'chunk_gla',
|
||||
'fused_chunk_gla',
|
||||
'fused_recurrent_gla'
|
||||
]
|
||||
734
finetune/lora/v6/fla/ops/gla/chunk.py
vendored
Normal file
@@ -0,0 +1,734 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.ops.utils import chunk_reversed_cumsum_fwd
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BS': 16}, num_warps=2),
|
||||
triton.Config({'BS': 16}, num_warps=4),
|
||||
triton.Config({'BS': 16}, num_warps=8),
|
||||
triton.Config({'BS': 32}, num_warps=2),
|
||||
triton.Config({'BS': 32}, num_warps=4),
|
||||
triton.Config({'BS': 32}, num_warps=8),
|
||||
triton.Config({'BS': 64}, num_warps=2),
|
||||
triton.Config({'BS': 64}, num_warps=4),
|
||||
triton.Config({'BS': 64}, num_warps=8),
|
||||
],
|
||||
key=['S']
|
||||
)
|
||||
@triton.jit
|
||||
def chunk_gla_fwd_kernel_cum(
|
||||
s,
|
||||
o,
|
||||
s_s_h,
|
||||
s_s_t,
|
||||
s_s_d,
|
||||
T: tl.constexpr,
|
||||
S: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BS: tl.constexpr
|
||||
):
|
||||
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
||||
|
||||
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
# [BT, BS]
|
||||
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = tl.dot(m_s, b_s, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_fwd_kernel_h(
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
h,
|
||||
h0,
|
||||
ht,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
for i_t in range(NT):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
if i_t < NT - 1:
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
else:
|
||||
b_gn = tl.min(b_g, axis=1)
|
||||
b_h *= tl.exp(b_gn)[:, None]
|
||||
b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
|
||||
b_h += tl.dot(b_k, b_v, allow_tf32=False)
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_fwd_kernel_intra(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
A,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
NC: tl.constexpr
|
||||
):
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
|
||||
n_bh = tl.num_programs(2)
|
||||
|
||||
if i_i > i_j:
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype)
|
||||
# [BK, BC]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)
|
||||
# [BC, BC]
|
||||
b_A = tl.dot(b_qg, b_kg, allow_tf32=False)
|
||||
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
|
||||
elif i_i == i_j:
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
|
||||
o_i = tl.arange(0, BC)
|
||||
o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
|
||||
m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
|
||||
for j in range(0, BC):
|
||||
# [BK,]
|
||||
b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
|
||||
b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
|
||||
# [BC,]
|
||||
b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1)
|
||||
b_A = tl.where(o_i >= j, b_A, 0.)
|
||||
tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)
|
||||
|
||||
p_k = tl.advance(p_k, (K,))
|
||||
p_gk = tl.advance(p_gk, (K,))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_fwd_kernel_inter(
|
||||
q,
|
||||
v,
|
||||
g,
|
||||
h,
|
||||
o,
|
||||
A,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BT, BK]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
# [BT, BK]
|
||||
b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# works but dkw, owing to divine benevolence
|
||||
# [BT, BV]
|
||||
if i_k >= 0:
|
||||
b_o += tl.dot(b_qg, b_h, allow_tf32=False)
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
b_o += tl.dot(b_A, b_v, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_bwd_kernel_dh(
|
||||
q,
|
||||
g,
|
||||
do,
|
||||
dh,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for i_t in range(NT - 1, -1, -1):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
# [BK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BT, BV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
|
||||
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BK, BV]
|
||||
b_dh *= tl.exp(b_gn)[:, None]
|
||||
# [BK, BT]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
b_q = (b_q * tl.exp(b_g)).to(b_q.dtype)
|
||||
|
||||
# [BK, BV]
|
||||
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_bwd_kernel_inter(
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
A,
|
||||
do,
|
||||
dh,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
dA,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
n_bh = tl.num_programs(2)
|
||||
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
|
||||
|
||||
# [BT, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk)
|
||||
b_k = (b_k * b_gn).to(b_k.dtype)
|
||||
# [BT, BT]
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BV, BK]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh = tl.load(p_dh, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BV]
|
||||
b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
|
||||
if i_k == 0:
|
||||
b_dv += tl.dot(b_A, b_do, allow_tf32=False)
|
||||
b_do = (b_do * scale).to(b_do.dtype)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_do, b_h, allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
|
||||
b_dq = b_dq * tl.exp(b_gk)
|
||||
b_dk = b_dk * b_gn
|
||||
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
# [BT, BT]
|
||||
b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
|
||||
if i_k == 0:
|
||||
tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_bwd_kernel_intra(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
dA,
|
||||
dq,
|
||||
dk,
|
||||
dg,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
NC: tl.constexpr
|
||||
):
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_t, i_i = i_c // NC, i_c % NC
|
||||
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BC, BK]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
b_dq = tl.zeros([BC, BK], dtype=tl.float32)
|
||||
for i_j in range(0, i_i):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
|
||||
# [BC, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
|
||||
# [BC, BC]
|
||||
b_dA = tl.load(p_dA, boundary_check=(0, 1))
|
||||
# [BC, BK]
|
||||
b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
|
||||
b_dq *= tl.exp(b_g - b_gn[None, :])
|
||||
|
||||
o_i = tl.arange(0, BC)
|
||||
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
|
||||
m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
|
||||
for j in range(0, BC):
|
||||
p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
|
||||
p_gkj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BC,]
|
||||
b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
|
||||
# [BK,]
|
||||
b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
|
||||
b_gkj = tl.load(p_gkj, boundary_check=(0,)).to(tl.float32)
|
||||
# [BC, BK]
|
||||
m_i = o_i[:, None] >= j
|
||||
# [BC, BK]
|
||||
b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_g - b_gkj[None, :]), 0.)
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
|
||||
b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
tl.debug_barrier()
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BC, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_dk = tl.zeros([BC, BK], dtype=tl.float32)
|
||||
for i_j in range(i_i + 1, NC):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
b_qg = (b_q * tl.exp(b_g - b_gn[None, :])).to(b_q.dtype)
|
||||
# [BC, BC]
|
||||
b_dA = tl.load(p_dA, boundary_check=(0, 1))
|
||||
# [BC, BK]
|
||||
b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False)
|
||||
b_dk *= tl.exp(b_gn[None, :] - b_gk)
|
||||
|
||||
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
|
||||
for j in range(0, BC):
|
||||
p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
|
||||
p_gqj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BC,]
|
||||
b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
|
||||
# [BK,]
|
||||
b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
|
||||
b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32)
|
||||
# [BC, BK]
|
||||
m_i = o_i[:, None] <= j
|
||||
b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.)
|
||||
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dg = tl.make_block_ptr(dg + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1))
|
||||
b_dg = b_q * b_dq - b_k * b_dk
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class ChunkGLAFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level):
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT, BC = 64, 16
|
||||
BK = min(64, triton.next_power_of_2(K))
|
||||
BV = min(64, triton.next_power_of_2(V))
|
||||
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
|
||||
NK = triton.cdiv(K, BK)
|
||||
NV = triton.cdiv(V, BV)
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
num_stages = 1
|
||||
|
||||
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NV, NK, B * H)
|
||||
chunk_gla_fwd_kernel_h[grid](
|
||||
k, v, g, h, h0, ht,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=h0 is not None,
|
||||
STORE_FINAL_STATE=ht is not None,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return h
|
||||
|
||||
final_state = None
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(B, H, K, V, dtype=torch.float)
|
||||
|
||||
g_org, g = g, torch.empty_like(g, dtype=torch.float)
|
||||
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
|
||||
# keep cumulative normalizer in fp32
|
||||
# this kernel is equivalent to
|
||||
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
|
||||
chunk_gla_fwd_kernel_cum[grid](
|
||||
g_org, g,
|
||||
g.stride(1), g.stride(2), g.stride(3),
|
||||
T=T, S=K, BT=BT
|
||||
)
|
||||
h = fwd_inner(
|
||||
q=q, k=k, v=v, g=g,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
h0=initial_state if initial_state is not None else None,
|
||||
ht=final_state if final_state is not None else None
|
||||
)
|
||||
A = q.new_zeros(NK, B, H, T, BT)
|
||||
grid = (NK, NT * NC * NC, B * H)
|
||||
chunk_gla_fwd_kernel_intra[grid](
|
||||
q, k, g, A,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
scale,
|
||||
T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
A = A.sum(0, dtype=A.dtype)
|
||||
o = torch.empty_like(v)
|
||||
grid = (NV, NT, B * H)
|
||||
chunk_gla_fwd_kernel_inter[grid](
|
||||
q, v, g, h, o, A,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
if checkpoint_level >= 1:
|
||||
del g
|
||||
g = g_org
|
||||
if checkpoint_level > 1:
|
||||
del h
|
||||
h, initial_state = None, None
|
||||
|
||||
ctx.save_for_backward(q, k, v, g, h, initial_state, A)
|
||||
ctx.BT = BT
|
||||
ctx.scale = scale
|
||||
ctx.checkpoint_level = checkpoint_level
|
||||
return o, final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, do, dht=None):
|
||||
q, k, v, g, h, initial_state, A = ctx.saved_tensors
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT, BC = ctx.BT, 16
|
||||
BK = min(64, triton.next_power_of_2(K))
|
||||
BV = min(64, triton.next_power_of_2(V))
|
||||
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
|
||||
NK = triton.cdiv(K, BK)
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
num_stages = 1
|
||||
|
||||
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NV, NK, B * H)
|
||||
chunk_gla_fwd_kernel_h[grid](
|
||||
k, v, g, h, h0, ht,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=h0 is not None,
|
||||
STORE_FINAL_STATE=ht is not None,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return h
|
||||
|
||||
def bwd_inner(q, g, do, B, H, T, K, V, BT, BK, BV, NT, scale):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
dh = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_gla_bwd_kernel_dh[grid](
|
||||
q, g, do, dh,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
do.stride(1), do.stride(2), do.stride(3),
|
||||
dh.stride(1), dh.stride(2), dh.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return dh
|
||||
|
||||
if ctx.checkpoint_level >= 1:
|
||||
# save the original g and recompute its fp32 cumsum during the backward pass to save memory
|
||||
g_org, g = g, torch.zeros_like(g, dtype=torch.float)
|
||||
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
|
||||
# keep cumulative normalizer in fp32
|
||||
# this kernel is equivalent to
|
||||
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
|
||||
chunk_gla_fwd_kernel_cum[grid](
|
||||
g_org, g,
|
||||
g.stride(1), g.stride(2), g.stride(3),
|
||||
T=T, S=K, BT=BT
|
||||
)
|
||||
|
||||
# rerun the forward pass to get h if checkpoint_level >= 1
|
||||
if ctx.checkpoint_level > 1:
|
||||
h = fwd_inner(
|
||||
q=q, k=k, v=v, g=g,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
h0=initial_state if initial_state is not None else None,
|
||||
ht=None
|
||||
)
|
||||
|
||||
scale = ctx.scale
|
||||
dh = bwd_inner(
|
||||
q, g, do,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
scale=scale
|
||||
)
|
||||
dq = torch.empty_like(q, dtype=torch.float)
|
||||
dk = torch.empty_like(k, dtype=torch.float)
|
||||
dg = torch.empty_like(k, dtype=torch.float)
|
||||
dv = v.new_empty(NK, *v.shape)
|
||||
dA = q.new_zeros(B, H, T, BT)
|
||||
grid = (NK, NT, B * H)
|
||||
chunk_gla_bwd_kernel_inter[grid](
|
||||
k, v, h, g, A, do, dh, dq, dk, dv, dA,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
dv = dv.sum(0, dtype=dv.dtype)
|
||||
grid = (NK, NT * NC, B * H)
|
||||
chunk_gla_bwd_kernel_intra[grid](
|
||||
q, k, g, dA, dq, dk, dg,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
dq = dq.to(q.dtype)
|
||||
dk = dk.to(q.dtype)
|
||||
# reversed cumsum, equivalent to:
|
||||
#
|
||||
# def reversed_cumsum(x, dim=-1):
|
||||
# c = x.cumsum(dim)
|
||||
# return x + c.index_select(dim, x.new_tensor([c.shape[dim]-1], dtype=torch.long)) - c
|
||||
dg = chunk_reversed_cumsum_fwd(dg).to(k.dtype)
|
||||
return dq, dk, dv, dg, None, None, None, None
|
||||
|
||||
|
||||
def chunk_gla(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
scale: Optional[int] = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
checkpoint_level: Optional[int] = 2
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `(B, H, T, K)`
|
||||
k (torch.Tensor):
|
||||
keys of shape `(B, H, T, K)`
|
||||
v (torch.Tensor):
|
||||
values of shape `(B, H, T, V)`
|
||||
g (torch.Tensor):
|
||||
Forget gates of shape `(B, H, T, K)` applied to keys.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the GLA attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `(B, H, K, V)`. Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
|
||||
checkpoint_level (Optional[int]):
|
||||
Checkpointing level; higher values save more memory at the cost of more recomputation during the backward pass.
|
||||
Default: `2`:
|
||||
- Level `0`: no memory saved, no recomputation.
|
||||
- Level `1`: recompute the fp32 cumulative values during backward.
|
||||
- Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
|
||||
"""
|
||||
assert checkpoint_level in [0, 1, 2]
|
||||
if scale is None:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level)
|
||||
return o, final_state
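A minimal usage sketch (illustrative only, not part of the vendored file); it assumes CUDA tensors shaped (B, H, T, K) / (B, H, T, V) and forget gates already in log space:

import torch
import torch.nn.functional as F
B, H, T, K, V = 2, 4, 256, 64, 64
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')
# forget gates are log-space decays, e.g. obtained via logsigmoid
g = F.logsigmoid(torch.randn(B, H, T, K, device='cuda'))
o, final_state = chunk_gla(q, k, v, g, output_final_state=True)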
548
finetune/lora/v6/fla/ops/gla/chunk_fuse.py
vendored
Normal file
@@ -0,0 +1,548 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023, Songlin Yang
|
||||
# Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635
|
||||
# on-the-fly computation without materializing hidden states into HBM
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.ops.gla.chunk_util import (bwd_decay_global_cumsum, fwd_decay_cumsum,
|
||||
prepare_qg_kg)
|
||||
from fla.utils import contiguous
|
||||
|
||||
inv_ln2 = 1.44269504
|
||||
|
||||
@triton.jit
|
||||
def fused_chunk_gla_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
g, # cumulative sum of log decay [B, H, L, D_head_K]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
CHECK: tl.constexpr
|
||||
):
|
||||
# indices
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
# make block pointers
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
|
||||
p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
|
||||
for i in range(0, tl.cdiv(T, BT)):
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)
|
||||
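# the two branches below are intentionally identical; splitting the first iteration out under
# `CHECK` works around the Triton compiler issue referenced in the warning emitted by
# FusedChunkGLAFunction.forward (openai/triton#2852)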
if CHECK and i == 0:
|
||||
b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
|
||||
b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
|
||||
else:
|
||||
b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
|
||||
b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
|
||||
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
p_q = tl.advance(p_q, (BT, 0))
|
||||
p_k = tl.advance(p_k, (0, BT))
|
||||
p_v = tl.advance(p_v, (BT, 0))
|
||||
p_o = tl.advance(p_o, (BT, 0))
|
||||
p_db += BT * DK
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
|
||||
@triton.jit
|
||||
def fused_chunk_gla_bwd_kernel(
|
||||
q, k, v, g,
|
||||
do, # gradient of output [B, H, L, D_head_V]
|
||||
dq, # gradient of query [NV, B, H, L, D_head_K]
|
||||
dk, # gradient of key [NV, B, H, L, D_head_K]
|
||||
dv, # gradient of value [NK, B, H, L, D_head_V]
|
||||
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
CHECK: tl.constexpr
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
# [BV, BK]
|
||||
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
|
||||
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
for i in range(0, tl.cdiv(T, BT)):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
# [BT, DK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# b_g = tl.load(p_g, boundary_check=(0, 1)) * inv_ln2
|
||||
d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)
|
||||
|
||||
# [DV, BT]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, DV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [DV, DK]
|
||||
if CHECK and i == 0:
|
||||
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
|
||||
b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
|
||||
else:
|
||||
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
|
||||
b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
|
||||
b_dq *= scale
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
# sync threads
|
||||
b_h = None
|
||||
tl.debug_barrier()
|
||||
# [BK, BV]
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
# cum = tl.zeros([BK], dtype=tl.float32)
|
||||
for i in range(1, tl.cdiv(T, BT) + 1):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
# [DK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BT, DK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, DV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32)
|
||||
|
||||
# inter-chunk
|
||||
# [DK, DV]
|
||||
if CHECK and i == 1:
|
||||
b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
|
||||
b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
|
||||
b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
|
||||
else:
|
||||
b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
|
||||
b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
|
||||
b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
|
||||
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fwd_inner_chunk(
|
||||
q, k, g, A,
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
):
|
||||
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
o_i = tl.arange(0, BT)
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
|
||||
p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
|
||||
p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
|
||||
|
||||
for i in range(BT):
|
||||
_q = tl.load(p_q, mask=mask, other=0) * scale
|
||||
gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
|
||||
s = _q[None, :] * b_k * tl.math.exp2(gq[None, :] - b_g)
|
||||
score = tl.sum(s, axis=1)
|
||||
score = tl.where(o_i <= i, score, 0)
|
||||
tl.store(p_A, score.to(p_A.dtype.element_ty))
|
||||
p_q += DK
|
||||
p_gq += DK
|
||||
p_A += BT
|
||||
|
||||
|
||||
@triton.jit
|
||||
def bwd_inner_chunk(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
dA,
|
||||
dq,
|
||||
dk,
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
):
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
o_i = tl.arange(0, BT)
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
|
||||
p_dq = dq + (i_bh) * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
|
||||
p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
|
||||
p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
|
||||
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
|
||||
for i in range(BT):
|
||||
_q = tl.load(p_q, mask=mask, other=0)
|
||||
gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
|
||||
score = tl.math.exp2(gq[None, :] - b_g)
|
||||
score = tl.where(o_i[:, None] <= i, score, 0)
|
||||
_dA = tl.load(p_dA)
|
||||
_dA = tl.where(o_i <= i, _dA, 0)
|
||||
b_dk += (_dA[:, None] * score * _q[None, :])
|
||||
b_dq = tl.sum(_dA[:, None] * score * b_k, axis=0)
|
||||
tl.store(p_dq, b_dq, mask=mask)
|
||||
p_q += DK
|
||||
p_dq += DK
|
||||
p_gq += DK
|
||||
p_dA += BT
|
||||
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class FusedChunkGLAFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
|
||||
ctx.g_dtype = g.dtype
|
||||
g_original = g
|
||||
# cumulative decay should be in float32, otherwise the err will be accumulated and amplified.
|
||||
g = torch.empty_like(g, dtype=torch.float32)
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
ctx.scale = scale
|
||||
|
||||
# inter-chunk
|
||||
BT = 16 # chunk_size
|
||||
BK, BV = min(d_head_qk, 64), min(d_head_v, 64)
|
||||
num_stages = 1
|
||||
num_warps = 2
|
||||
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
q_g = torch.empty_like(q)
|
||||
k_g = torch.empty_like(k)
|
||||
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
|
||||
fwd_decay_cumsum[grid](
|
||||
g_original,
|
||||
g,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
|
||||
)
|
||||
prepare_qg_kg[grid](
|
||||
q, k, g, q_g, k_g,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
|
||||
)
|
||||
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float, requires_grad=False)
|
||||
else:
|
||||
final_state = None
|
||||
# the bug still exists even for Triton 2.2 on H100 GPUs
|
||||
# so we always enable initial checks
|
||||
CHECK = True
|
||||
if version.parse(triton.__version__) < version.parse('2.2.0'):
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Triton<2.2.0 detected for running this kernel, "
|
||||
"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
|
||||
"that lead to significant precision loss. "
|
||||
"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
|
||||
"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
|
||||
)
|
||||
CHECK = True
|
||||
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
fused_chunk_gla_fwd_kernel[grid](
|
||||
q_g, k_g, v, g, o, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=output_final_state,
|
||||
CHECK=CHECK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
o = o.sum(0)
|
||||
|
||||
# intra-chunk
|
||||
chunk_size = 16
|
||||
num_chunk = seq_len // chunk_size
|
||||
v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
|
||||
BK = min(d_head_qk, 64)
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
A = q.new_empty(NK, batch_size, n_heads, triton.cdiv(seq_len, BT), BT, BT)
|
||||
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
|
||||
fwd_inner_chunk[grid](
|
||||
q, k, g, A,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale, BT=BT, BK=BK, DK=d_head_qk, num_stages=3,
|
||||
num_warps=4
|
||||
)
|
||||
A = A.sum(0)
|
||||
o2 = A @ v2
|
||||
o2 = rearrange(o2, 'b h n c d -> b h (n c) d')
|
||||
# combine inner and inter
|
||||
o.add_(o2)
|
||||
ctx.save_for_backward(q, k, v, g_original, A, initial_state)
|
||||
ctx.CHECK = CHECK
|
||||
return o.to(v), final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_bwd
|
||||
def backward(ctx, do, d_final_state=None):
|
||||
q, k, v, g_origin, A, initial_state = ctx.saved_tensors
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
scale = ctx.scale
|
||||
|
||||
# recomputation
|
||||
# inter-chunk
|
||||
BT = 16 # chunk_size
|
||||
g = torch.empty_like(g_origin, dtype=torch.float32)
|
||||
BK, BV = min(d_head_qk, 64), min(d_head_v, 64)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
q_g = torch.empty_like(q)
|
||||
k_g = torch.empty_like(k)
|
||||
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
|
||||
fwd_decay_cumsum[grid](
|
||||
g_origin,
|
||||
g,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
|
||||
)
|
||||
prepare_qg_kg[grid](
|
||||
q, k, g, q_g, k_g,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
|
||||
)
|
||||
|
||||
# inter-chunk
|
||||
BT = 16
|
||||
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 2
|
||||
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
|
||||
fused_chunk_gla_bwd_kernel[grid](
|
||||
q_g, k_g, v, g, do, dq, dk, dv, initial_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
# clamp_min=-3,
|
||||
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
CHECK=ctx.CHECK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
dq = dq.sum(0)
|
||||
dk = dk.sum(0)
|
||||
dv = dv.sum(0)
|
||||
|
||||
# intra chunk
|
||||
num_chunk = seq_len // BT
|
||||
v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
|
||||
do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk)
|
||||
dA2 = (do2 @ v2.transpose(-2, -1)) * scale
|
||||
dv2 = A.transpose(-1, -2) @ do2
|
||||
dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk)
|
||||
|
||||
BK = min(triton.next_power_of_2(d_head_qk), 16)
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
dk2 = torch.empty_like(k)
|
||||
dq2 = torch.empty_like(q)
|
||||
|
||||
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
|
||||
bwd_inner_chunk[grid](
|
||||
q, k, g,
|
||||
dA2, dq2, dk2,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, BK=BK,
|
||||
num_warps=1,
|
||||
num_stages=3
|
||||
)
|
||||
|
||||
BK = min(triton.next_power_of_2(d_head_qk), 32)
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
dg = torch.empty_like(g, dtype=torch.float32)
|
||||
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
|
||||
bwd_decay_global_cumsum[grid](
|
||||
dq2, dq, dk2, dk, q, k, g, dg,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, BK=BK,
|
||||
num_warps=1,
|
||||
num_stages=1
|
||||
)
|
||||
dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)
|
||||
|
||||
def rev_cumsum_exclusive(x):
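# reversed exclusive cumulative sum over the chunk axis: out[..., i, :] = sum of x[..., j, :] for j > i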
|
||||
cumsum_x = x.cumsum(-2)
|
||||
rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x
|
||||
return rev_cumsum_x
|
||||
|
||||
rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])
|
||||
dg.add_(rev_cumsum_dg.unsqueeze(-2))
|
||||
dv.add_(dv2)
|
||||
dg = rearrange(dg, 'b h n c d -> b h (n c) d')
|
||||
|
||||
return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None
|
||||
|
||||
|
||||
def pad(x, chunk_size=16):
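# right-pad the sequence dimension to the next multiple of `chunk_size` so fixed-size chunks tile it exactly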
|
||||
seq_len = x.shape[-2]
|
||||
padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size
|
||||
if x.shape[-2] % chunk_size != 0:
|
||||
x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def ceildiv(a, b):
|
||||
return -(a // -b)
|
||||
|
||||
|
||||
def fused_chunk_gla(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
scale: int = -1,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if scale == -1:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
seq_len = q.shape[-2]
|
||||
q, k, v, g = map(lambda x: pad(x), [q, k, v, g])
|
||||
o, final_state = FusedChunkGLAFunction.apply(
|
||||
q, k, v, g, scale, initial_state, output_final_state)
|
||||
o = o[..., :seq_len, :]
|
||||
return o, final_state
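

# Illustrative usage sketch: sizes and dtypes below are assumptions chosen for demonstration.
# `fused_chunk_gla` pads the inputs to a multiple of the 16-token chunk size internally and
# crops the output back, so arbitrary sequence lengths are accepted.
if __name__ == "__main__":
    B, H, T, K, V = 2, 4, 100, 64, 64  # T deliberately not a multiple of 16
    q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
    g = F.logsigmoid(torch.randn(B, H, T, K, device='cuda')).to(torch.bfloat16)
    o, final_state = fused_chunk_gla(q, k, v, g, output_final_state=True)
    print(o.shape)  # (B, H, T, V), cropped back to the original length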
|
||||
138
finetune/lora/v6/fla/ops/gla/chunk_util.py
vendored
Normal file
@@ -0,0 +1,138 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
inv_ln2 = 1.44269504  # 1 / ln(2): converts natural-log gates to base 2 so kernels can use tl.math.exp2
|
||||
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fwd_decay_cumsum(
|
||||
g,
|
||||
g_o,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
B,
|
||||
H,
|
||||
T,
|
||||
scale,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
DK: tl.constexpr
|
||||
):
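# for each length-BT chunk, turn the per-step log decay `g` into a running within-chunk
# cumulative sum, converted to base 2 via inv_ln2 so later kernels can use tl.math.exp2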
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
cum_decay = tl.zeros([BK], dtype=tl.float32)
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
|
||||
for i in range(BT):
|
||||
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
|
||||
cum_decay += _g * inv_ln2
|
||||
tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)
|
||||
p_g += DK
|
||||
p_go += DK
|
||||
|
||||
@triton.jit
|
||||
def prepare_qg_kg(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
qg,
|
||||
kg,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
B,
|
||||
H,
|
||||
T,
|
||||
scale,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
DK: tl.constexpr
|
||||
):
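# pre-scale q and k by the cumulative decay so the inter-chunk kernel can use plain matmuls:
# qg = q * exp2(cum_g) * scale and kg = k * exp2(last_decay - cum_g)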
|
||||
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
|
||||
last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))
|
||||
|
||||
for i in range(BT):
|
||||
_q = tl.load(p_q, mask=mask, other=0)
|
||||
_k = tl.load(p_k, mask=mask, other=0)
|
||||
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
|
||||
_q *= tl.math.exp2(_g) * scale
|
||||
_k *= tl.math.exp2(last_decay - _g)
|
||||
tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)
|
||||
tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)
|
||||
p_q += DK
|
||||
p_g += DK
|
||||
p_k += DK
|
||||
p_kg += DK
|
||||
p_qg += DK
|
||||
|
||||
|
||||
@triton.jit
|
||||
def bwd_decay_global_cumsum(
|
||||
dq_inner,
|
||||
dq_inter,
|
||||
dk_inner,
|
||||
dk_inter,
|
||||
q, k, g, dg,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
B,
|
||||
H,
|
||||
T,
|
||||
scale,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
DK: tl.constexpr
|
||||
):
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
cum_grad_dg = tl.zeros([BK], dtype=tl.float32)
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
last_g = tl.zeros([BK], dtype=tl.float32)
|
||||
for j in range(BT-1, -1, -1):
|
||||
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
|
||||
if j == (BT-1):
|
||||
last_g = _g
|
||||
_dq1 = tl.load(p_dq_inner, mask=mask, other=0)
|
||||
_dq2 = tl.load(p_dq_inter, mask=mask, other=0)
|
||||
_dq2 *= tl.math.exp2(_g)
|
||||
_dq = _dq1 + _dq2
|
||||
tl.store(p_dq_inter, _dq, mask=mask)
|
||||
_dk1 = tl.load(p_dk_inner, mask=mask, other=0)
|
||||
_dk2 = tl.load(p_dk_inter, mask=mask, other=0)
|
||||
_dk2 *= tl.math.exp2(last_g - _g)
|
||||
_dk = _dk1 + _dk2
|
||||
tl.store(p_dk_inter, _dk, mask=mask)
|
||||
_q = tl.load(p_q, mask=mask, other=0)
|
||||
_k = tl.load(p_k, mask=mask, other=0)
|
||||
_dg = _dq * _q - _dk * _k
|
||||
cum_grad_dg += _dg
|
||||
tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)
|
||||
p_g -= DK
|
||||
p_k -= DK
|
||||
p_q -= DK
|
||||
p_dq_inner -= DK
|
||||
p_dk_inner -= DK
|
||||
p_dq_inter -= DK
|
||||
p_dk_inter -= DK
|
||||
p_dg -= DK
|
||||
|
||||
116
finetune/lora/v6/fla/ops/gla/naive.py
vendored
Normal file
@@ -0,0 +1,116 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fla.ops.gla.recurrent_fuse import fused_recurrent_gla
|
||||
|
||||
|
||||
def ceildiv(a, b):
|
||||
return -(a // -b)
|
||||
|
||||
|
||||
def naive_recurrent_gla(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
gk,
|
||||
initial_state=None,
|
||||
output_final_state=False,
|
||||
causal=True
|
||||
):
|
||||
orig_dtype = q.dtype
|
||||
q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk))
|
||||
batch_size, n_heads, seq_len, d_head_k = q.shape
|
||||
_, _, _, d_head_v = v.shape
|
||||
h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
|
||||
o = torch.zeros_like(v)
|
||||
scale = d_head_k ** -0.5
|
||||
|
||||
if initial_state is not None:
|
||||
h += initial_state
|
||||
|
||||
for i in range(seq_len):
|
||||
q_i = q[:, :, i, :] * scale
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i, :]
|
||||
gk_i = gk[:, :, i].exp()
|
||||
kv_i = k_i[..., None] * v_i[..., None, :]
|
||||
h = h * gk_i[..., None] + kv_i
|
||||
o_i = (q_i[..., None] * h).sum(-2)
|
||||
o[:, :, i] = o_i
|
||||
|
||||
if causal:
|
||||
return o.to(orig_dtype), h
|
||||
else:
|
||||
o_reverse = torch.zeros_like(v)
|
||||
h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
|
||||
for i in range(seq_len-1, -1, -1):
|
||||
q_i = q[:, :, i, :] * scale
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i, :]
|
||||
gk_i = gk[:, :, i].exp()
|
||||
kv_i = k_i[..., None] * v_i[..., None, :]
|
||||
h = h * gk_i[..., None] + kv_i
|
||||
o_i = (q_i[..., None] * h).sum(-2)
|
||||
o_reverse[:, :, i] = o_i
|
||||
|
||||
return o, o_reverse
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
B = 4
|
||||
H = 4
|
||||
L = 512
|
||||
D = 128
|
||||
dtype = torch.float32
|
||||
q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
|
||||
k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
|
||||
v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
|
||||
g = F.logsigmoid(torch.rand(B, H, L, D)).cuda().clamp_min(-1).to(torch.float32).requires_grad_(True)
|
||||
|
||||
do = torch.rand_like(v).cuda()
|
||||
do2 = torch.rand_like(v).cuda()
|
||||
initial_state = torch.rand(B, H, D, D).cuda()
|
||||
|
||||
ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False)
|
||||
|
||||
ref.backward(do, retain_graph=True)
|
||||
ref_rev.backward(do2, retain_graph=True)
|
||||
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dg, g.grad = g.grad.clone(), None
|
||||
|
||||
tri, tri_rev = fused_recurrent_gla(
|
||||
q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, causal=False)
|
||||
tri.backward(do, retain_graph=True)
|
||||
tri_rev.backward(do2, retain_graph=True)
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dg, g.grad = g.grad.clone(), None
|
||||
|
||||
assert ref.allclose(tri, 0, 1e-5), breakpoint()
|
||||
assert ref_rev.allclose(tri_rev, 0, 1e-5), breakpoint()
|
||||
assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
|
||||
assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
|
||||
assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
|
||||
assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
|
||||
|
||||
# tri = fused_chunk_gla(q, k, v, g)
|
||||
# tri.backward(do, retain_graph=True)
|
||||
# tri_dq, q.grad = q.grad.clone(), None
|
||||
# tri_dk, k.grad = k.grad.clone(), None
|
||||
# tri_dv, v.grad = v.grad.clone(), None
|
||||
# tri_dg, g.grad = g.grad.clone(), None
|
||||
|
||||
# assert ref.allclose(tri, 0, 1e-5), breakpoint()
|
||||
# assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
|
||||
# assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
|
||||
# assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
|
||||
# assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
|
||||
# breakpoint()
|
||||
print("Pass")
|
||||
404
finetune/lora/v6/fla/ops/gla/recurrent_fuse.py
vendored
Normal file
@@ -0,0 +1,404 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
# on-the-fly computation without materializing hidden states into HBM
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_gla_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
gk, # log gate [B, H, L, D_head_K]
|
||||
gv, # log gate [B, H, L, D_head_V]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
# initial hidden state initialization [B, H, D_head_K, D_head_V]
|
||||
initial_state,
|
||||
final_state, # final hidden state [B, H, D_head_K, D_head_V]
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
|
||||
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
|
||||
USE_GK: tl.constexpr, # whether to use gk
|
||||
USE_GV: tl.constexpr, # whether to use gv
|
||||
):
|
||||
# indices
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
|
||||
if USE_GK:
|
||||
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
if USE_GV:
|
||||
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
|
||||
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
|
||||
|
||||
h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
|
||||
mask_kv = mask_bk[None, :] & mask_bv[:, None]
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_init_s = initial_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[None, :]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
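# one timestep per iteration; gk/gv were already exponentiated by the Python wrapper, so they act
# as multiplicative decays here:
#   h = gv[:, None] * gk[None, :] * h + v[:, None] * k[None, :]
#   o = sum over the K axis of h * (q * scale)[None, :]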
for _ in range(0, T):
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
if USE_GK:
|
||||
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
|
||||
h = h * _gk[None, :]
|
||||
if USE_GV:
|
||||
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
|
||||
h = h * _gv[:, None]
|
||||
h += _k[None, :] * _v[:, None]
|
||||
_o = h * _q[None, :]
|
||||
_o = tl.sum(_o, axis=1)
|
||||
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
|
||||
p_q += -DK if REVERSE else DK
|
||||
p_k += -DK if REVERSE else DK
|
||||
p_o += -DV if REVERSE else DV
|
||||
p_v += -DV if REVERSE else DV
|
||||
if USE_GK:
|
||||
p_gk += -DK if REVERSE else DK
|
||||
if USE_GV:
|
||||
p_gv += -DV if REVERSE else DV
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_final_s = final_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[None, :]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
|
||||
|
||||
|
||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
|
||||
@triton.jit
|
||||
def fused_recurrent_gla_bwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
# NV: number of split in the V dimension. NK: number of split in the K dimension
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
gk, # log gate [B, H, L, D_head_K] \alpha
|
||||
gv, # log gate [B, H, L, D_head_V] \beta
|
||||
|
||||
do, # gradient of output [B, H, L, D_head_V]
|
||||
dq, # gradient of query [NV, B, H, L, D_head_K]
|
||||
dk, # gradient of key [NV, B, H, L, D_head_K]
|
||||
dv, # gradient of value [NK, B, H, L, D_head_V]
|
||||
|
||||
# initial hidden state initialization [B, H, D_head_K, D_head_V]
|
||||
initial_state,
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
|
||||
USE_GK: tl.constexpr, # whether to use gk
|
||||
USE_GV: tl.constexpr, # whether to use gv
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
p_do = do + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
if USE_GK:
|
||||
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
if USE_GV:
|
||||
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
mask_bk = i_k * BK + tl.arange(0, BK) < DK
|
||||
mask_bv = i_v * BV + tl.arange(0, BV) < DV
|
||||
mask_kv = mask_bk[:, None] & mask_bv[None, :]
|
||||
h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_init_s = initial_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[:, None]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[None, :])
|
||||
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
for i in range(0, T):
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
if USE_GK:
|
||||
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
|
||||
h = h * _gk[:, None]
|
||||
if USE_GV:
|
||||
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
|
||||
h = h * _gv[None, :]
|
||||
h += _k[:, None] * _v[None, :]
|
||||
_d_q = h * _do[None, :]
|
||||
d_q = tl.sum(_d_q, axis=1) * scale
|
||||
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
|
||||
|
||||
p_k += -DK if REVERSE else DK
|
||||
p_v += -DV if REVERSE else DV
|
||||
p_q += -DK if REVERSE else DK
|
||||
p_do += -DV if REVERSE else DV
|
||||
p_dq += -DK if REVERSE else DK
|
||||
if USE_GK:
|
||||
p_gk += -DK if REVERSE else DK
|
||||
if USE_GV:
|
||||
p_gv += -DV if REVERSE else DV
|
||||
|
||||
# sync threads
|
||||
tl.debug_barrier()
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
|
||||
p_do = do + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
|
||||
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
|
||||
BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
|
||||
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
|
||||
BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
|
||||
if USE_GK:
|
||||
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
|
||||
if USE_GV:
|
||||
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
|
||||
|
||||
d_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
for _ in range(T):
|
||||
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
d_h += _q[:, None] * _do[None, :]
|
||||
d_k = tl.sum(d_h * _v[None, :], axis=1)
|
||||
d_v = tl.sum(d_h * _k[:, None], axis=0)
|
||||
if USE_GK:
|
||||
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
|
||||
d_h *= _gk[:, None]
|
||||
if USE_GV:
|
||||
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
|
||||
d_h *= _gv[None, :]
|
||||
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
|
||||
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
|
||||
|
||||
p_do += DV if REVERSE else -DV
|
||||
p_q += DK if REVERSE else -DK
|
||||
p_k += DK if REVERSE else -DK
|
||||
p_v += DV if REVERSE else -DV
|
||||
p_dk += DK if REVERSE else -DK
|
||||
p_dv += DV if REVERSE else -DV
|
||||
if USE_GK:
|
||||
p_gk += DK if REVERSE else -DK
|
||||
if USE_GV:
|
||||
p_gv += DV if REVERSE else -DV
|
||||
|
||||
|
||||
class FusedRecurrentGLAFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
# default scale
|
||||
if scale is None:
|
||||
scale = d_head_qk ** -0.5
|
||||
if gk is not None:
|
||||
gk = gk.float().exp()
|
||||
if gv is not None:
|
||||
gv = gv.float().exp()
|
||||
|
||||
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
o = q.new_empty(NK, batch_size, n_heads, seq_len,
|
||||
d_head_v, dtype=torch.float32)
|
||||
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
|
||||
else:
|
||||
final_state = None
|
||||
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
fused_recurrent_gla_fwd_kernel[grid](
|
||||
q, k, v, gk, gv, o, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=final_state is not None,
|
||||
USE_GK=gk is not None,
|
||||
USE_GV=gv is not None,
|
||||
REVERSE=reverse,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
o = o.sum(0)
|
||||
ctx.save_for_backward(q, k, v, gk, gv, initial_state, o)
|
||||
ctx.scale = scale
|
||||
ctx.reverse = reverse
|
||||
# we do not need the gradient of the final state from the next chunk
|
||||
# similar to truncated BPTT
|
||||
if final_state is not None:
|
||||
final_state = final_state.detach()
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_bwd
|
||||
def backward(ctx, do, d_final_state=None):
|
||||
q, k, v, gk, gv, initial_state, o = ctx.saved_tensors
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
scale = ctx.scale
|
||||
|
||||
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
dq = q.new_empty(NV, batch_size, n_heads, seq_len,
|
||||
d_head_qk, dtype=torch.float32)
|
||||
dk = q.new_empty(NV, batch_size, n_heads, seq_len,
|
||||
d_head_qk, dtype=torch.float32)
|
||||
dv = q.new_empty(NK, batch_size, n_heads, seq_len,
|
||||
d_head_v, dtype=torch.float32)
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
|
||||
fused_recurrent_gla_bwd_kernel[grid](
|
||||
q, k, v, gk, gv, do, dq, dk, dv, initial_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
REVERSE=ctx.reverse,
|
||||
USE_GK=gk is not None,
|
||||
USE_GV=gv is not None
|
||||
)
|
||||
dq = dq.sum(0)
|
||||
dk = dk.sum(0)
|
||||
dv = dv.sum(0)
|
||||
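# gradient of the log gates: gk_t enters every step s >= t through the cumulative decay, so dgk_t
# is the suffix sum over s >= t of (dq_s * q_s - dk_s * k_s); the non-reverse case builds it as the
# total cumsum minus the exclusive prefix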
if gk is not None:
|
||||
_dgk = dq * q.float() - dk * k.float()
|
||||
if ctx.reverse:
|
||||
dgk = _dgk.cumsum(-2)
|
||||
else:
|
||||
_dgk_cumsum = _dgk.cumsum(-2)
|
||||
dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum
|
||||
else:
|
||||
dgk = None
|
||||
|
||||
if gv is not None:
|
||||
_dgv = do.float() * o.float() - dv * v.float()
|
||||
if ctx.reverse:
|
||||
dgv = _dgv.cumsum(-2)
|
||||
else:
|
||||
_dgv_cumsum = _dgv.cumsum(-2)
|
||||
dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum
|
||||
else:
|
||||
dgv = None
|
||||
|
||||
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None
|
||||
|
||||
|
||||
# if scale is left at -1 (the default), d_head_qk ** -0.5 is used. Otherwise specify the scale yourself, e.g. scale = 1.0
|
||||
def fused_recurrent_gla(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
gk: torch.Tensor = None,
|
||||
gv: torch.Tensor = None,
|
||||
scale: int = -1,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
causal: bool = True
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if scale == -1:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
if causal:
|
||||
o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state)
|
||||
return o, final_state
|
||||
else:
|
||||
# initial_state is not supported yet; it is not well defined for bidirectional modeling
|
||||
assert initial_state is None
|
||||
assert output_final_state is False
|
||||
o, final_state = FusedRecurrentGLAFunction.apply(
|
||||
q, k, v, gk, gv, scale, initial_state, output_final_state, False)
|
||||
o_reversed, final_state = FusedRecurrentGLAFunction.apply(
|
||||
q, k, v, gk, gv, scale, initial_state, output_final_state, True)
|
||||
return [o, o_reversed]
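

# Illustrative usage sketch: sizes are assumptions chosen for demonstration; `gk` is a log-space
# gate on the key dimension and `gv` (a gate on the value dimension) is simply omitted here.
if __name__ == "__main__":
    import torch.nn.functional as F
    B, H, T, K, V = 2, 4, 256, 64, 64
    q = torch.randn(B, H, T, K, device='cuda')
    k = torch.randn(B, H, T, K, device='cuda')
    v = torch.randn(B, H, T, V, device='cuda')
    gk = F.logsigmoid(torch.randn(B, H, T, K, device='cuda'))
    # causal (default): returns the output and, if requested, the final (K, V) state
    o, final_state = fused_recurrent_gla(q, k, v, gk, output_final_state=True)
    # bidirectional: returns [forward output, reverse-direction output] instead
    o_fwd, o_rev = fused_recurrent_gla(q, k, v, gk, causal=False)
    print(o.shape, final_state.shape, o_rev.shape)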
|
||||
9
finetune/lora/v6/fla/ops/hgrn/__init__.py
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .chunk import chunk_hgrn
|
||||
from .recurrent_fuse import fused_recurrent_hgrn
|
||||
|
||||
__all__ = [
|
||||
'chunk_hgrn',
|
||||
'fused_recurrent_hgrn'
|
||||
]
|
||||
373
finetune/lora/v6/fla/ops/hgrn/chunk.py
vendored
Normal file
@@ -0,0 +1,373 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2024, Yu Zhang, Songlin Yang
|
||||
|
||||
# this function implements the chunkwise form of HGRN, inspired by
|
||||
# [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html)
|
||||
# also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan
|
||||
|
||||
# from tests on H800, with B, H, D = 16, 4, 128, we see that the chunk can be greatly faster than the recurrent:
|
||||
#
|
||||
# Performance:
|
||||
# seq_len chunk recurrent chunk_bwd recurrent_bwd
|
||||
# 0 128.0 0.039360 0.061056 0.312160 0.205008
|
||||
# 1 256.0 0.045824 0.123712 0.308784 0.297696
|
||||
# 2 512.0 0.058688 0.241952 0.310720 0.626528
|
||||
# 3 1024.0 0.088288 0.476992 0.313184 1.333152
|
||||
# 4 2048.0 0.169472 0.943264 0.452464 2.724864
|
||||
# 5 4096.0 0.329920 1.886144 0.881600 5.551520
|
||||
# 6 8192.0 0.647872 3.755040 1.740496 11.117184
|
||||
# 7 16384.0 1.272064 7.520576 3.446608 22.362528
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BD': 32}, num_warps=1),
|
||||
triton.Config({'BD': 32}, num_warps=2),
|
||||
triton.Config({'BD': 32}, num_warps=4),
|
||||
triton.Config({'BD': 32}, num_warps=8),
|
||||
triton.Config({'BD': 64}, num_warps=1),
|
||||
triton.Config({'BD': 64}, num_warps=2),
|
||||
triton.Config({'BD': 64}, num_warps=4),
|
||||
triton.Config({'BD': 64}, num_warps=8),
|
||||
triton.Config({'BD': 128}, num_warps=1),
|
||||
triton.Config({'BD': 128}, num_warps=2),
|
||||
triton.Config({'BD': 128}, num_warps=4),
|
||||
triton.Config({'BD': 128}, num_warps=8),
|
||||
],
|
||||
key=['D']
|
||||
)
|
||||
@triton.jit
|
||||
def chunk_hgrn_fwd_kernel_h(
|
||||
x,
|
||||
g,
|
||||
gc,
|
||||
o,
|
||||
h0,
|
||||
T: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BD: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr
|
||||
):
|
||||
i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
o_d = i_d * BD + tl.arange(0, BD)
|
||||
mask = o_d < D
|
||||
|
||||
p_x = x + i_bh * T * D + i_t * BT * D + o_d
|
||||
p_g = g + i_bh * T * D + i_t * BT * D + o_d
|
||||
p_gc = gc + i_bh * T * D + i_t * BT * D + o_d
|
||||
p_o = o + i_bh * T * D + i_t * BT * D + o_d
|
||||
|
||||
b_h = tl.zeros([BD], dtype=tl.float32)
|
||||
b_gc = tl.zeros([BD], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
if i_t == 0:
|
||||
b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
|
||||
for i in range(0, BT):
|
||||
mask_t = mask & ((i_t * BT + i) < T)
|
||||
b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)
|
||||
b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)
|
||||
b_h = tl.exp(b_g) * b_h + b_x
|
||||
b_gc = b_gc + b_g
|
||||
tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)
|
||||
tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)
|
||||
|
||||
p_x += D
|
||||
p_g += D
|
||||
p_gc += D
|
||||
p_o += D
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_hgrn_fwd_kernel_o(
|
||||
gc,
|
||||
o,
|
||||
s_h,
|
||||
s_t,
|
||||
s_d,
|
||||
T: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BD: tl.constexpr
|
||||
):
|
||||
i_d, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
o_d = i_d * BD + tl.arange(0, BD)
|
||||
mask = o_d < D
|
||||
|
||||
for i_t in range(1, tl.cdiv(T, BT)):
|
||||
p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
|
||||
|
||||
# [BD,]
|
||||
b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)
|
||||
# [BT, BD]
|
||||
b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = b_o + tl.exp(b_gc) * b_h0[None, :]
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BD': 32}, num_warps=1),
|
||||
triton.Config({'BD': 32}, num_warps=2),
|
||||
triton.Config({'BD': 32}, num_warps=4),
|
||||
triton.Config({'BD': 32}, num_warps=8),
|
||||
triton.Config({'BD': 64}, num_warps=1),
|
||||
triton.Config({'BD': 64}, num_warps=2),
|
||||
triton.Config({'BD': 64}, num_warps=4),
|
||||
triton.Config({'BD': 64}, num_warps=8),
|
||||
triton.Config({'BD': 128}, num_warps=1),
|
||||
triton.Config({'BD': 128}, num_warps=2),
|
||||
triton.Config({'BD': 128}, num_warps=4),
|
||||
triton.Config({'BD': 128}, num_warps=8),
|
||||
],
|
||||
key=['D']
|
||||
)
|
||||
@triton.jit
|
||||
def chunk_hgrn_bwd_kernel_h(
|
||||
g,
|
||||
gc,
|
||||
dx,
|
||||
do,
|
||||
T: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BD: tl.constexpr
|
||||
):
|
||||
i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
o_d = i_d * BD + tl.arange(0, BD)
|
||||
mask = o_d < D
|
||||
BC = min(BT, T - i_t * BT)
|
||||
NT = tl.num_programs(1)
|
||||
|
||||
p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d
|
||||
p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d
|
||||
p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d
|
||||
p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d
|
||||
|
||||
if i_t == NT - 1:
|
||||
b_gc = tl.zeros([BD], dtype=tl.float32)
|
||||
else:
|
||||
b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)
|
||||
b_dh = tl.zeros([BD], dtype=tl.float32)
|
||||
for _ in range(BC - 1, -1, -1):
|
||||
tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)
|
||||
|
||||
b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
|
||||
b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
|
||||
|
||||
b_gc = b_gc + b_g
|
||||
b_dh = b_dh + b_do
|
||||
b_dx = b_dh
|
||||
b_dh = b_dh * tl.exp(b_g)
|
||||
|
||||
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
|
||||
|
||||
p_g -= D
|
||||
p_gc -= D
|
||||
p_dx -= D
|
||||
p_do -= D
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_hgrn_bwd_kernel_o(
|
||||
g,
|
||||
gc,
|
||||
o,
|
||||
dx,
|
||||
dg,
|
||||
s_h,
|
||||
s_t,
|
||||
s_d,
|
||||
T: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BD: tl.constexpr
|
||||
):
|
||||
i_d, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
o_d = i_d * BD + tl.arange(0, BD)
|
||||
mask = o_d < D
|
||||
|
||||
for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
|
||||
p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))
|
||||
p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
|
||||
p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
|
||||
|
||||
# [BD,]
|
||||
mask_t = mask & ((i_t + 1) * BT < T)
|
||||
b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)
|
||||
# [BT, BD]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_dg = tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]
|
||||
b_dg = b_o * b_dx * tl.exp(b_g)
|
||||
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class ChunkHGRNFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(ctx, x, g, initial_state=None, output_final_state=False):
|
||||
B, H, T, D = x.shape
|
||||
BT, BD = 128, min(64, triton.next_power_of_2(D))
|
||||
num_warps = 8 if BD == 64 else 4
|
||||
|
||||
gc = torch.empty_like(g, dtype=torch.float)
|
||||
o = torch.empty_like(x, dtype=torch.float)
|
||||
def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
|
||||
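# two-pass chunked scan: kernel_h runs the recurrence inside each chunk from a zero carry (chunk 0
# may start from h0) and records the running gate sums in `gc`; kernel_o then folds in the carry
# from the previous chunk via o += exp(gc) * h_prev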
chunk_hgrn_fwd_kernel_h[grid](
|
||||
x, g, gc, o, initial_state,
|
||||
T, D,
|
||||
BT=BT,
|
||||
USE_INITIAL_STATE=initial_state is not None
|
||||
)
|
||||
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
|
||||
chunk_hgrn_fwd_kernel_o[grid](
|
||||
gc, o,
|
||||
o.stride(1), o.stride(2), o.stride(3),
|
||||
T, D,
|
||||
BT=BT, BD=BD,
|
||||
num_warps=num_warps
|
||||
)
|
||||
final_state = None
|
||||
if output_final_state:
|
||||
final_state = o[:, :, -1].clone()
|
||||
o = o.to(x.dtype)
|
||||
ctx.save_for_backward(g, o, initial_state)
|
||||
return o, final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, do, dht=None):
|
||||
g, o, initial_state = ctx.saved_tensors
|
||||
B, H, T, D = do.shape
|
||||
BT, BD = 128, min(64, triton.next_power_of_2(D))
|
||||
num_warps = 8 if BD == 64 else 4
|
||||
|
||||
gc = torch.empty_like(g, dtype=torch.float)
|
||||
dx = torch.empty_like(o)
|
||||
dg = torch.empty_like(g)
|
||||
def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
|
||||
chunk_hgrn_bwd_kernel_h[grid](
|
||||
g, gc, dx, do,
|
||||
T, D,
|
||||
BT=BT
|
||||
)
|
||||
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
|
||||
chunk_hgrn_bwd_kernel_o[grid](
|
||||
g, gc, o, dx, dg,
|
||||
o.stride(1), o.stride(2), o.stride(3),
|
||||
T, D,
|
||||
BT=BT, BD=BD,
|
||||
num_warps=num_warps
|
||||
)
|
||||
if initial_state is not None:
|
||||
dg[:, :, 0] = initial_state * dx[:, :, 0] * g[:, :, 0].exp()
|
||||
|
||||
return dx, dg, None, None
|
||||
|
||||
|
||||
def chunk_hgrn(
|
||||
x: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
o, final_state = ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)
|
||||
return o, final_state
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fla.ops.hgrn.naive import naive_recurrent_hgrn
|
||||
from fla.ops.hgrn.recurrent_fuse import fused_recurrent_hgrn
|
||||
B, H, T, D = 8, 4, 512, 128
|
||||
dtype = torch.bfloat16
|
||||
torch.manual_seed(42)
|
||||
# [batch_size, n_heads, seq_len, d_head]
|
||||
x = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
|
||||
g = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
|
||||
x, g = (1 - g.sigmoid()) * x, F.logsigmoid(g)
|
||||
print(f'x:\t{float(x.min()):>10.6f}\t{float(x.max()):>10.6f}')
|
||||
print(f'g:\t{float(g.min()):>10.6f}\t{float(g.max()):>10.6f}')
|
||||
x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g))
|
||||
print(f"DTYPE:\t{x.dtype}")
|
||||
do = torch.randn_like(x)
|
||||
h0 = torch.randn_like(x[:, :, 0])
|
||||
ref, ref_ht = naive_recurrent_hgrn(x, g, h0, output_final_state=True)
|
||||
ref.backward(do)
|
||||
ref_dx, x.grad = x.grad.clone(), None
|
||||
ref_dg, g.grad = g.grad.clone(), None
|
||||
|
||||
tri, tri_ht = fused_recurrent_hgrn(x, g, h0, output_final_state=True)
|
||||
tri.backward(do)
|
||||
tri_dx, x.grad = x.grad.clone(), None
|
||||
tri_dg, g.grad = g.grad.clone(), None
|
||||
print(" \t DIFF\t MAX")
|
||||
print(' o\t', f"{float((ref - tri).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}")
|
||||
print('ht\t', f"{float((ref_ht[0] - tri_ht[0]).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}")
|
||||
print('dx\t', f"{float((ref_dx - tri_dx).abs().max()):>10.6f}\t{float(ref_dx.max()):>10.6f}")
|
||||
print('dg\t', f"{float((ref_dg - tri_dg).abs().max()):>10.6f}\t{float(ref_dg.max()):>10.6f}")
|
||||
print('Done!')
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
# argument names to use as an x-axis for the plot
|
||||
x_names=['seq_len'],
|
||||
# different possible values for `x_name`
|
||||
x_vals=[128 * 2 ** i for i in range(0, 8)],
|
||||
# argument name whose value corresponds to a different line in the plot
|
||||
line_arg='provider',
|
||||
# possible values for `line_arg``
|
||||
line_vals=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
|
||||
# label name for the lines
|
||||
line_names=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')],
|
||||
ylabel="Execution Time (ms)", # label name for the y-axis
|
||||
# name for the plot. Used also as a file name for saving the plot.
|
||||
plot_name="Performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(seq_len, provider):
|
||||
dtype = torch.bfloat16
|
||||
B, H, D = 16, 4, 128
|
||||
|
||||
x = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda')
|
||||
g = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda').sigmoid()
|
||||
x = (1 - g) * x
|
||||
x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g))
|
||||
do = torch.randn_like(x, dtype=dtype)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
results = 0, 0, 0
|
||||
if provider == 'chunk':
|
||||
results = triton.testing.do_bench(lambda: chunk_hgrn(x, g), quantiles=quantiles)
|
||||
if provider == 'recurrent':
|
||||
results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g), quantiles=quantiles)
|
||||
if provider == 'chunk_bwd':
|
||||
results = triton.testing.do_bench(lambda: chunk_hgrn(x, g)[0].backward(do), quantiles=quantiles)
|
||||
if provider == 'recurrent_bwd':
|
||||
results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g)[0].backward(do), quantiles=quantiles)
|
||||
return results
|
||||
benchmark.run(print_data=True)
|
||||
31
finetune/lora/v6/fla/ops/hgrn/naive.py
vendored
Normal file
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-

from typing import Optional

import torch


def naive_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> torch.Tensor:
    dtype = x.dtype
    x, g = map(lambda i: i.float(), (x, g))
    B, H, T, D = x.shape

    h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
    o = torch.zeros_like(x)

    final_state = None
    if initial_state is not None:
        h += initial_state.detach()

    for i in range(T):
        h = g[:, :, i].exp() * h + x[:, :, i]
        o[:, :, i] = h

    if output_final_state:
        final_state = h
    return o.to(dtype), final_state
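
if __name__ == '__main__':
    # Illustrative sanity check (added; shapes are assumptions, runs on CPU).
    # The loop above implements h_t = exp(g_t) * h_{t-1} + x_t elementwise.
    B, H, T, D = 2, 2, 16, 8
    x = torch.randn(B, H, T, D)
    g = torch.randn(B, H, T, D).sigmoid().log()
    o, ht = naive_recurrent_hgrn(x, g, output_final_state=True)
    print(o.shape, ht.shape)  # [2, 2, 16, 8] and [2, 2, 8]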
185
finetune/lora/v6/fla/ops/hgrn/recurrent_fuse.py
vendored
Normal file
@@ -0,0 +1,185 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BD': 32}, num_warps=1),
|
||||
triton.Config({'BD': 32}, num_warps=2),
|
||||
triton.Config({'BD': 32}, num_warps=4),
|
||||
triton.Config({'BD': 32}, num_warps=8),
|
||||
triton.Config({'BD': 64}, num_warps=1),
|
||||
triton.Config({'BD': 64}, num_warps=2),
|
||||
triton.Config({'BD': 64}, num_warps=4),
|
||||
triton.Config({'BD': 64}, num_warps=8),
|
||||
triton.Config({'BD': 128}, num_warps=1),
|
||||
triton.Config({'BD': 128}, num_warps=2),
|
||||
triton.Config({'BD': 128}, num_warps=4),
|
||||
triton.Config({'BD': 128}, num_warps=8),
|
||||
],
|
||||
key=['D']
|
||||
)
|
||||
@triton.jit
|
||||
def fused_recurrent_hgrn_fwd_kernel(
|
||||
x,
|
||||
g,
|
||||
o,
|
||||
h0,
|
||||
ht,
|
||||
T: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BD: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr
|
||||
):
|
||||
i_d, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
o_d = i_d * BD + tl.arange(0, BD)
|
||||
mask = o_d < D
|
||||
|
||||
p_x = x + i_bh * T * D + o_d
|
||||
p_g = g + i_bh * T * D + o_d
|
||||
p_o = o + i_bh * T * D + o_d
|
||||
|
||||
b_h = tl.zeros([BD], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = h0 + i_bh * D + o_d
|
||||
b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
|
||||
for _ in range(0, T):
|
||||
b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
|
||||
b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
|
||||
b_h = tl.exp(b_g) * b_h + b_x
|
||||
tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)
|
||||
|
||||
p_x += D
|
||||
p_g += D
|
||||
p_o += D
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = ht + i_bh * D + o_d
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BD': 32}, num_warps=1),
|
||||
triton.Config({'BD': 32}, num_warps=2),
|
||||
triton.Config({'BD': 32}, num_warps=4),
|
||||
triton.Config({'BD': 32}, num_warps=8),
|
||||
triton.Config({'BD': 64}, num_warps=1),
|
||||
triton.Config({'BD': 64}, num_warps=2),
|
||||
triton.Config({'BD': 64}, num_warps=4),
|
||||
triton.Config({'BD': 64}, num_warps=8),
|
||||
triton.Config({'BD': 128}, num_warps=1),
|
||||
triton.Config({'BD': 128}, num_warps=2),
|
||||
triton.Config({'BD': 128}, num_warps=4),
|
||||
triton.Config({'BD': 128}, num_warps=8),
|
||||
],
|
||||
key=['D']
|
||||
)
|
||||
@triton.jit
|
||||
def fused_recurrent_hgrn_bwd_kernel(
|
||||
g,
|
||||
o,
|
||||
dx,
|
||||
dg,
|
||||
do,
|
||||
h0,
|
||||
T: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BD: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr
|
||||
):
|
||||
i_d, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
o_d = i_d * BD + tl.arange(0, BD)
|
||||
mask = o_d < D
|
||||
|
||||
p_g = g + (i_bh * T + T - 1) * D + o_d
|
||||
p_o = o + (i_bh * T + T - 2) * D + o_d
|
||||
p_dx = dx + (i_bh * T + T - 1) * D + o_d
|
||||
p_dg = dg + (i_bh * T + T - 1) * D + o_d
|
||||
p_do = do + (i_bh * T + T - 1) * D + o_d
|
||||
|
||||
b_dh = tl.zeros([BD], dtype=tl.float32)
|
||||
for i in range(T - 1, -1, -1):
|
||||
b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
|
||||
b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
|
||||
if i > 0:
|
||||
b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
|
||||
elif USE_INITIAL_STATE:
|
||||
b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
|
||||
else:
|
||||
b_o = tl.zeros([BD], dtype=tl.float32)
|
||||
|
||||
b_dh = b_dh + b_do
|
||||
b_dx = b_dh
|
||||
b_dh = b_dh * tl.exp(b_g)
|
||||
b_dg = b_dh * b_o
|
||||
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
|
||||
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)
|
||||
|
||||
p_g -= D
|
||||
p_o -= D
|
||||
p_dx -= D
|
||||
p_dg -= D
|
||||
p_do -= D
|
||||
|
||||
|
||||
class FusedRecurrentHGRNFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(ctx, x, g, initial_state=None, output_final_state=False):
|
||||
B, H, T, D = x.shape
|
||||
|
||||
final_state = None
|
||||
if output_final_state:
|
||||
final_state = x.new_empty(B, H, D)
|
||||
|
||||
o = torch.empty_like(x)
|
||||
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
|
||||
fused_recurrent_hgrn_fwd_kernel[grid](
|
||||
x, g, o, initial_state, final_state,
|
||||
T, D,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=final_state is not None
|
||||
)
|
||||
ctx.save_for_backward(g, o, initial_state)
|
||||
return o, final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, do, dht=None):
|
||||
g, o, initial_state = ctx.saved_tensors
|
||||
B, H, T, D = do.shape
|
||||
|
||||
dx = torch.empty_like(o)
|
||||
dg = torch.empty_like(g)
|
||||
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
|
||||
fused_recurrent_hgrn_bwd_kernel[grid](
|
||||
g, o, dx, dg, do, initial_state,
|
||||
T, D,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
)
|
||||
|
||||
return dx, dg, None, None
|
||||
|
||||
|
||||
def fused_recurrent_hgrn(
|
||||
x: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state)
|
||||
return o, final_state
|
||||
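
# Usage sketch (added for illustration; not part of the upstream file). One program
# handles a BD-wide slice of D for a single (batch, head) pair and steps through the
# whole sequence in registers, so CUDA tensors are required. Shapes are assumptions.
if __name__ == '__main__':
    if torch.cuda.is_available():
        x = torch.randn(2, 4, 64, 32, device='cuda', dtype=torch.bfloat16)
        g = torch.randn_like(x, dtype=torch.float).sigmoid().log().to(x.dtype)
        o, ht = fused_recurrent_hgrn(x, g, output_final_state=True)
        print(o.shape, ht.shape)  # [2, 4, 64, 32] and [2, 4, 32]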
12
finetune/lora/v6/fla/ops/linear_attn/__init__.py
vendored
Normal file
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_linear_attn
from .chunk_fuse import fused_chunk_linear_attn
from .recurrent_fuse import fused_recurrent_linear_attn

__all__ = [
    'chunk_linear_attn',
    'fused_chunk_linear_attn',
    'fused_recurrent_linear_attn'
]
359
finetune/lora/v6/fla/ops/linear_attn/chunk.py
vendored
Normal file
@@ -0,0 +1,359 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def normalize_output(q, k, o):
|
||||
k = k.transpose(-2, -1)
|
||||
k = k.cumsum(-1)
|
||||
k = k.transpose(-2, -1)
|
||||
z = (q * k).sum(-1, keepdim=True)
|
||||
return o / (z + 1e-5)
|
||||
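# Added note: normalize_output divides each output by the linear-attention
# denominator z_t = q_t . (sum_{s<=t} k_s) + 1e-5, i.e. the softmax-free analogue of
# the attention normalizer, computed with a causal cumulative sum over the keys.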
|
||||
|
||||
@triton.jit
|
||||
def chunk_linear_attn_fwd_kernel_h(
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
# [BK, BV]
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
for i_t in range(NT):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_h += tl.dot(b_k, b_v, allow_tf32=False)
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
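# Added note: kernel_h above materializes the running [BK, BV] state at the start of
# every chunk into `h` (and optionally the final state), so that kernel_o below can
# form the output as the inter-chunk term q @ h plus the causally masked
# intra-chunk term (q k^T) v.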
|
||||
|
||||
@triton.jit
|
||||
def chunk_linear_attn_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
o,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_s = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
b_o += tl.dot(b_q, b_h, allow_tf32=False)
|
||||
b_s += tl.dot(b_q, b_k, allow_tf32=False)
|
||||
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_linear_attn_bwd_kernel_dh(
|
||||
q,
|
||||
do,
|
||||
dh,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
# [BK, BV]
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for i_t in range(NT - 1, -1, -1):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BT, V]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_linear_attn_bwd_kernel_dqkv(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
do,
|
||||
dh,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
n_bh = tl.num_programs(2)
|
||||
o_i = tl.arange(0, BT)
|
||||
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
|
||||
b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)
|
||||
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BV, BK]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh = tl.load(p_dh, boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
|
||||
# [BT, BV]
|
||||
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
|
||||
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
|
||||
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class ChunkLinearAttentionFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
@contiguous
|
||||
def forward(ctx, q, k, v, scale, initial_state, output_final_state):
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT = 64
|
||||
BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
|
||||
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
ctx.scale = scale
|
||||
|
||||
final_state = None
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
|
||||
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_linear_attn_fwd_kernel_h[grid](
|
||||
k, v, h, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2),
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=output_final_state,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
grid = (NV, NT, B * H)
|
||||
o = torch.empty_like(v)
|
||||
chunk_linear_attn_fwd_kernel_o[grid](
|
||||
q, k, v, h, o,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2),
|
||||
scale,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, h)
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
@contiguous
|
||||
def backward(ctx, do, d_ht=None):
|
||||
q, k, v, h = ctx.saved_tensors
|
||||
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT = 64
|
||||
BK, BV = min(64, triton.next_power_of_2(K)), min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
|
||||
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
scale = ctx.scale
|
||||
|
||||
dh = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_linear_attn_bwd_kernel_dh[grid](
|
||||
q, do, dh,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
dh.stride(1), dh.stride(2),
|
||||
scale,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
grid = (NK, NT, B * H)
|
||||
dq = torch.empty_like(q)
|
||||
dk = torch.empty_like(k)
|
||||
dv = v.new_empty(NK, *v.shape)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
chunk_linear_attn_bwd_kernel_dqkv[grid](
|
||||
q, k, v, h, do, dh, dq, dk, dv,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
dh.stride(1), dh.stride(2),
|
||||
scale,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
dv = dv.sum(0)
|
||||
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
|
||||
|
||||
|
||||
def chunk_linear_attn(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
scale: float = -1,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
normalize: bool = True
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if scale == -1:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
o, final_state = ChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
|
||||
|
||||
if normalize:
|
||||
o = normalize_output(q * scale, k, o)
|
||||
|
||||
return o, final_state
|
||||
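
# Usage sketch (added; illustrative only, shapes are assumptions, requires CUDA):
if __name__ == '__main__':
    if torch.cuda.is_available():
        B, H, T, D = 2, 4, 256, 64
        q, k, v = (torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16) for _ in range(3))
        o, _ = chunk_linear_attn(q, k, v)
        print(o.shape)  # [2, 4, 256, 64]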
326
finetune/lora/v6/fla/ops/linear_attn/chunk_fuse.py
vendored
Normal file
@@ -0,0 +1,326 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from packaging import version
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
# on-the-fly computation without materializing hidden states into HBM
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def normalize_output(q, k, o):
|
||||
k = k.transpose(-2, -1)
|
||||
k = k.cumsum(-1)
|
||||
k = k.transpose(-2, -1)
|
||||
z = (q * k).sum(-1, keepdim=True)
|
||||
return o / (z + 1e-5)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_chunk_linear_attn_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
    k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
CHECK: tl.constexpr
|
||||
):
|
||||
# indices
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
|
||||
# [BT, BT]
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
# [BK, BV]
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
# make block pointers
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
for i in range(0, tl.cdiv(T, BT)):
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_k.dtype)
|
||||
|
||||
# [BT, BT]
|
||||
b_s = tl.dot(b_q, b_k, allow_tf32=False)
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
# [BT, BV]
|
||||
b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
|
||||
if CHECK and i == 0:
|
||||
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
|
||||
b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
|
||||
else:
|
||||
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
|
||||
b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
p_q = tl.advance(p_q, (BT, 0))
|
||||
p_k = tl.advance(p_k, (0, BT))
|
||||
p_v = tl.advance(p_v, (BT, 0))
|
||||
p_o = tl.advance(p_o, (BT, 0))
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
|
||||
@triton.jit
|
||||
def fused_chunk_linear_attn_bwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
    # NV: number of splits in the V dimension. NK: number of splits in the K dimension
|
||||
q, # query [B, H, L, D_head_K]
|
||||
    k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
do, # gradient of output [B, H, L, D_head_V]
|
||||
dq, # gradient of query [NV, B, H, L, D_head_K]
|
||||
dk, # gradient of key [NV, B, H, L, D_head_K]
|
||||
dv, # gradient of value [NK, B, H, L, D_head_V]
|
||||
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
CHECK: tl.constexpr
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
o_i = tl.arange(0, BT)
|
||||
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
# [BV, BK]
|
||||
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
for i in range(0, tl.cdiv(T, BT)):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
|
||||
|
||||
# [BT, DK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [DV, BT]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, DV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BT]
|
||||
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
|
||||
b_ds = tl.where(m_s, b_ds, 0)
|
||||
# [BT, DK]
|
||||
b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
|
||||
# [DV, DK]
|
||||
if CHECK and i == 0:
|
||||
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
|
||||
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
|
||||
else:
|
||||
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
|
||||
b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
|
||||
b_dq *= scale
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
# sync threads
|
||||
b_h = None
|
||||
tl.debug_barrier()
|
||||
# [BK, BV]
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
m_s = o_i[:, None] <= o_i[None, :]
|
||||
for i in range(1, tl.cdiv(T, BT) + 1):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
|
||||
# [DK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BT, DK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, DV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
|
||||
# b_dd = (b_do]).to(b_do.dtype)
|
||||
|
||||
# [BT, BT]
|
||||
b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
|
||||
b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
|
||||
# [BT, BT]
|
||||
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
|
||||
b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
|
||||
# [BT, DK]
|
||||
b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
|
||||
# [BT, DV]
|
||||
b_dv = tl.dot(b_s, b_do, allow_tf32=False)
|
||||
if CHECK and i == 1:
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
|
||||
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
|
||||
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
|
||||
else:
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
|
||||
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
|
||||
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
|
||||
|
||||
tl.store(p_dk, (b_dk * scale).to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class FusedChunkLinearAttentionFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, q, k, v, scale, initial_state, output_final_state):
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
ctx.scale = scale
|
||||
BT = 64
|
||||
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4
|
||||
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
|
||||
else:
|
||||
final_state = None
|
||||
# the bug still exists even for Triton 2.2 on H100 GPUs
|
||||
# so we always enable initial checks
|
||||
CHECK = True
|
||||
if version.parse(triton.__version__) < version.parse('2.2.0'):
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Triton<2.2.0 detected for running this kernel, "
|
||||
"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
|
||||
"that lead to significant precision loss. "
|
||||
"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
|
||||
"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
|
||||
)
|
||||
CHECK = True
|
||||
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
fused_chunk_linear_attn_fwd_kernel[grid](
|
||||
q, k, v, o, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=output_final_state,
|
||||
CHECK=CHECK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
o = o.sum(0)
|
||||
ctx.save_for_backward(q, k, v, initial_state)
|
||||
ctx.CHECK = CHECK
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
@contiguous
|
||||
def backward(ctx, do, d_final_state=None):
|
||||
q, k, v, initial_state = ctx.saved_tensors
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
scale = ctx.scale
|
||||
|
||||
BT = 64
|
||||
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4
|
||||
|
||||
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
|
||||
fused_chunk_linear_attn_bwd_kernel[grid](
|
||||
q, k, v, do, dq, dk, dv, initial_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
CHECK=ctx.CHECK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
dq = dq.sum(0)
|
||||
dk = dk.sum(0)
|
||||
dv = dv.sum(0)
|
||||
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
|
||||
|
||||
|
||||
def fused_chunk_linear_attn(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
scale: float = -1,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
normalize: bool = True
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
if scale == -1:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
|
||||
if normalize:
|
||||
o = normalize_output(q * scale, k, o)
|
||||
return o, final_state
|
||||
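
# Usage sketch (added; illustrative only, shapes are assumptions, requires CUDA). The
# fused variant keeps the running [D_head_K, D_head_V] state on-chip rather than
# materializing per-chunk states, so it is called exactly like chunk_linear_attn.
if __name__ == '__main__':
    if torch.cuda.is_available():
        q, k, v = (torch.randn(2, 4, 256, 64, device='cuda', dtype=torch.bfloat16) for _ in range(3))
        o, ht = fused_chunk_linear_attn(q, k, v, output_final_state=True)
        print(o.shape, ht.shape)  # [2, 4, 256, 64] and [2, 4, 64, 64]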
20
finetune/lora/v6/fla/ops/linear_attn/naive.py
vendored
Normal file
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-

import torch
from einops import rearrange


def torch_chunk_linear_attn(q, k, v, chunk_size=64):
    q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * (q.shape[-1] ** -0.5)
    k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
    v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
    kv = k.transpose(-1, -2) @ v
    kv = kv.cumsum(2)
    kv = torch.cat([
        torch.zeros_like(kv[:, :, :1]),
        kv[:, :, :-1]
    ], dim=2)
    inter = q @ kv
    intra = ((q @ k.transpose(-1, -2)).masked_fill_(
        torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)) @ v
    o = inter + intra
    return rearrange(o, 'b h n c d -> b h (n c) d')
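
if __name__ == '__main__':
    # Equivalence sketch (added for illustration; sizes are assumptions, runs on CPU).
    # The chunkwise computation above should match the quadratic causal form
    # o = tril((q * d**-0.5) @ k^T) @ v up to floating-point rounding.
    B, H, T, D = 2, 2, 128, 16
    q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
    ref = torch.tril((q * D ** -0.5) @ k.transpose(-1, -2)) @ v
    out = torch_chunk_linear_attn(q, k, v, chunk_size=64)
    print((ref - out).abs().max())  # expected to be tiny (float32 rounding only)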
284
finetune/lora/v6/fla/ops/linear_attn/recurrent_fuse.py
vendored
Normal file
@@ -0,0 +1,284 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
# on-the-fly computation without materializing hidden states into HBM
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def normalize_output(q, k, o):
|
||||
k = k.transpose(-2, -1)
|
||||
k = k.cumsum(-1)
|
||||
k = k.transpose(-2, -1)
|
||||
z = (q * k).sum(-1, keepdim=True)
|
||||
return o / (z + 1e-5)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_linear_attn_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
    k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
initial_state,
|
||||
final_state, # final hidden state [B, H, D_head_K, D_head_V]
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
|
||||
):
|
||||
# indices
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||||
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||||
|
||||
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
|
||||
mask_kv = mask_bk[None, :] & mask_bv[:, None]
|
||||
|
||||
h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_init_s = initial_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[None, :]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
for _ in range(0, T):
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
|
||||
h += _k[None, :] * _v[:, None]
|
||||
_o = h * _q[None, :]
|
||||
_o = tl.sum(_o, axis=1)
|
||||
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
|
||||
|
||||
p_q += DK
|
||||
p_k += DK
|
||||
p_o += DV
|
||||
p_v += DV
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_final_s = final_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[None, :]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
|
||||
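# Added note: each step of the loop above performs the rank-1 state update
# h += k_t v_t^T and reads the output as o_t = h^T q_t (the kernel keeps h in a
# [BV, BK] layout, hence the transposed indexing), which is the standard recurrent
# form of (unnormalized) linear attention.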
|
||||
|
||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
|
||||
@triton.jit
|
||||
def fused_recurrent_linear_attn_bwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
    # NV: number of splits in the V dimension. NK: number of splits in the K dimension
|
||||
q, # query [B, H, L, D_head_K]
|
||||
    k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
|
||||
do, # gradient of output [B, H, L, D_head_V]
|
||||
dq, # gradient of query [NV, B, H, L, D_head_K]
|
||||
dk, # gradient of key [NV, B, H, L, D_head_K]
|
||||
dv, # gradient of value [NK, B, H, L, D_head_V]
|
||||
|
||||
# initial hidden state initialization [B, H, D_head_K, D_head_V]
|
||||
initial_state,
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||||
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||||
|
||||
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||||
mask_bk = i_k * BK + tl.arange(0, BK) < DK
|
||||
mask_bv = i_v * BV + tl.arange(0, BV) < DV
|
||||
|
||||
h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
mask_kv = mask_bk[:, None] & mask_bv[None, :]
|
||||
p_init_s = initial_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[:, None]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[None, :])
|
||||
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
for i in range(0, T):
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
|
||||
h += _k[:, None] * _v[None, :]
|
||||
_d_q = h * _do[None, :]
|
||||
d_q = tl.sum(_d_q, axis=1) * scale
|
||||
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
|
||||
|
||||
p_k += DK
|
||||
p_do += DV
|
||||
p_v += DV
|
||||
p_dq += DK
|
||||
|
||||
# sync threads
|
||||
tl.debug_barrier()
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
|
||||
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
|
||||
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
|
||||
BK + tl.arange(0, BK) + (T - 1) * DK
|
||||
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
|
||||
BV + tl.arange(0, BV) + (T - 1) * DV
|
||||
d_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
for _ in range(T):
|
||||
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
d_h += _q[:, None] * _do[None, :]
|
||||
d_k = tl.sum(d_h * _v[None, :], axis=1)
|
||||
d_v = tl.sum(d_h * _k[:, None], axis=0)
|
||||
|
||||
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
|
||||
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
|
||||
|
||||
p_do -= DV
|
||||
p_q -= DK
|
||||
p_k -= DK
|
||||
p_v -= DV
|
||||
p_dk -= DK
|
||||
p_dv -= DV
|
||||
|
||||
|
||||
class FusedRecurrentLinearAttentionFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(ctx, q, k, v, initial_state=None, output_final_state=False):
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
|
||||
scale = d_head_qk ** -0.5
|
||||
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
|
||||
else:
|
||||
final_state = None
|
||||
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
fused_recurrent_linear_attn_fwd_kernel[grid](
|
||||
q, k, v, o, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=final_state is not None
|
||||
)
|
||||
|
||||
o = o.sum(0)
|
||||
ctx.save_for_backward(q, k, v, initial_state)
|
||||
return o, final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, do, d_final_state=None):
|
||||
q, k, v, initial_state = ctx.saved_tensors
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
scale = d_head_qk ** -0.5
|
||||
|
||||
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
|
||||
fused_recurrent_linear_attn_bwd_kernel[grid](
|
||||
q, k, v, do, dq, dk, dv, initial_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state is not None
|
||||
)
|
||||
dq = dq.sum(0)
|
||||
dk = dk.sum(0)
|
||||
dv = dv.sum(0)
|
||||
return dq, dk, dv, None, None
|
||||
|
||||
|
||||
def fused_recurrent_linear_attn(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
normalize: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
o, final_state = FusedRecurrentLinearAttentionFunction.apply(
|
||||
q, k, v, initial_state, output_final_state)
|
||||
if normalize:
|
||||
o = normalize_output(q, k, o)
|
||||
return o, final_state
|
||||
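
# Usage sketch (added; illustrative only, shapes are assumptions, requires CUDA).
# Note that this variant fixes scale to d_head_qk ** -0.5 inside the autograd
# Function and that `normalize` defaults to False here.
if __name__ == '__main__':
    if torch.cuda.is_available():
        q, k, v = (torch.randn(2, 4, 128, 64, device='cuda') for _ in range(3))
        o, _ = fused_recurrent_linear_attn(q, k, v)
        print(o.shape)  # [2, 4, 128, 64]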
7
finetune/lora/v6/fla/ops/rebased/__init__.py
vendored
Normal file
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

from .parallel import parallel_rebased

__all__ = [
    'parallel_rebased'
]
80
finetune/lora/v6/fla/ops/rebased/naive.py
vendored
Normal file
@@ -0,0 +1,80 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from fla.ops.rebased.parallel import parallel_rebased
|
||||
|
||||
def naive_parallel_rebased(q, k, v, use_scale=True, use_norm=True):
|
||||
if use_scale:
|
||||
q = q * (q.shape[-1] ** -0.5)
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
attn = (attn ** 2)
|
||||
attn.masked_fill_(~torch.tril(torch.ones(
|
||||
q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
|
||||
o = attn @ v
|
||||
if use_norm:
|
||||
z = attn.sum(-1)
|
||||
return o / (z[..., None] + 1e-6)
|
||||
else:
|
||||
return o
|
||||
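# Added note: "rebased" replaces softmax attention with the squared dot-product
# kernel: attn_ij = (q_i . k_j)^2 under a causal mask, optionally normalized by the
# row sum z_i = sum_{j<=i} (q_i . k_j)^2; the Triton kernels in parallel.py compute
# the same quantity blockwise.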
|
||||
|
||||
if __name__ == "__main__":
|
||||
B = 4
|
||||
H = 4
|
||||
L = 128
|
||||
# D = 15
|
||||
dtype = torch.float32
|
||||
q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
|
||||
k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True)
|
||||
v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True)
|
||||
|
||||
do = torch.randn_like(v).cuda()
|
||||
ref = naive_parallel_rebased(q, k, v, True, True)
|
||||
ref.backward(do, retain_graph=True)
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
|
||||
# tri = naive_chunk_based(q, k, v)
|
||||
# tri.backward(do, retain_graph=True)
|
||||
# tri_dq, q.grad = q.grad.clone(), None
|
||||
# tri_dk, k.grad = k.grad.clone(), None
|
||||
# tri_dv, v.grad = v.grad.clone(), None
|
||||
|
||||
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
|
||||
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
|
||||
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
|
||||
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
|
||||
|
||||
tri = parallel_rebased(q, k, v, 1e-6, True, True)
|
||||
tri.backward(do, retain_graph=True)
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
print((ref-tri).abs().max())
|
||||
print((ref_dq-tri_dq).abs().max())
|
||||
print((ref_dk-tri_dk).abs().max())
|
||||
print((ref_dv-tri_dv).abs().max())
|
||||
|
||||
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
|
||||
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
|
||||
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
|
||||
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
|
||||
|
||||
# tri = parallel_based(q, k, v, True, True)
|
||||
# tri.backward(do, retain_graph=True)
|
||||
# tri_dq, q.grad = q.grad.clone(), None
|
||||
# tri_dk, k.grad = k.grad.clone(), None
|
||||
# tri_dv, v.grad = v.grad.clone(), None
|
||||
|
||||
# print((ref-tri).abs().max())
|
||||
# print((ref_dq-tri_dq).abs().max())
|
||||
# print((ref_dk-tri_dk).abs().max())
|
||||
# print((ref_dv-tri_dv).abs().max())
|
||||
|
||||
# assert ref.allclose(tri, 0, 1e-4), breakpoint()
|
||||
# assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint()
|
||||
# assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint()
|
||||
# assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint()
|
||||
387
finetune/lora/v6/fla/ops/rebased/parallel.py
vendored
Normal file
@@ -0,0 +1,387 @@
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
# Rebased: Linear Transformers with Learnable Kernel Functions are Better In-Context Models
|
||||
# https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py
|
||||
|
||||
|
||||
@triton.jit
|
||||
def parallel_rebased_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
    k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
z, # normalizer [B, H, L]
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q
|
||||
BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
):
|
||||
# i_c: chunk index. used for sequence parallelism
|
||||
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
NV = tl.cdiv(DV, BV)
|
||||
i_k = i_kv // (NV)
|
||||
i_v = i_kv % (NV)
|
||||
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
|
||||
(s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))
|
||||
|
||||
# [BQ, BD] block Q, in the shared memory throughout the whole kernel
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
b_o = tl.zeros([BTL, BV], dtype=tl.float32)
|
||||
b_z = tl.zeros([BTL], dtype=tl.float32)
|
||||
|
||||
# Q block and K block have no overlap
|
||||
# no need for mask, thereby saving flops
|
||||
for _ in range(0, i_c * BTL, BTS):
|
||||
# [BK, BTS]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
|
||||
# [BTS, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
b_s = tl.dot(b_q, (b_k), allow_tf32=False)
|
||||
b_s = b_s * b_s
|
||||
b_z += tl.sum(b_s, axis=1)
|
||||
|
||||
# [BQ, BD]
|
||||
b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
|
||||
p_k = tl.advance(p_k, (0, BTS))
|
||||
p_v = tl.advance(p_v, (BTS, 0))
|
||||
|
||||
# # rescale interchunk output
|
||||
tl.debug_barrier()
|
||||
o_q = tl.arange(0, BTL)
|
||||
# # sync threads, easy for compiler to optimize
|
||||
# tl.debug_barrier()
|
||||
|
||||
o_k = tl.arange(0, BTS)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
|
||||
(s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
|
||||
# Q block and K block have overlap. masks required
|
||||
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
|
||||
# [BK, BTS]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BTS, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
m_s = o_q[:, None] >= o_k[None, :]
|
||||
b_s = tl.dot(b_q, b_k, allow_tf32=False)
|
||||
b_s = b_s * b_s
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
b_z += tl.sum(b_s, axis=1)
|
||||
# [BTL, BV]
|
||||
b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
|
||||
p_k = tl.advance(p_k, (0, BTS))
|
||||
p_v = tl.advance(p_v, (BTS, 0))
|
||||
o_k += BTS
|
||||
|
||||
p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
|
||||
p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_z, b_z.to(p_z.dtype.element_ty),
|
||||
mask=((i_c * BTL + tl.arange(0, BTL)) < T))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _parallel_rebased_bwd_dq(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale,
|
||||
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
|
||||
DK: tl.constexpr, DV: tl.constexpr,
|
||||
):
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
|
||||
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
|
||||
p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
|
||||
(s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))
|
||||
p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)
|
||||
b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)
|
||||
|
||||
for _ in range(0, i_c * BTL, BTS):
|
||||
# [BTS, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BV, BTS]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
|
||||
if i_v == 0:
|
||||
b_ds += b_dz[:, None]
|
||||
else:
|
||||
b_ds = b_ds
|
||||
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
|
||||
# [BQ, BD]
|
||||
b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False)
|
||||
p_k = tl.advance(p_k, (BTS, 0))
|
||||
p_v = tl.advance(p_v, (0, BTS))
|
||||
|
||||
b_dq *= scale
|
||||
o_q = tl.arange(0, BTL)
|
||||
o_k = tl.arange(0, BTS)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
|
||||
(s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
|
||||
# Q block and K block have overlap. masks required
|
||||
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
|
||||
# [BTS, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BV, BTS]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
m_s = o_q[:, None] >= o_k[None, :]
|
||||
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
|
||||
if i_v == 0:
|
||||
b_ds += b_dz[:, None]
|
||||
else:
|
||||
b_ds = b_ds
|
||||
b_ds = tl.where(m_s, b_ds, 0) * scale
|
||||
b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
# [BTL, BK]
|
||||
b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype),
|
||||
b_k, allow_tf32=False)
|
||||
p_k = tl.advance(p_k, (BTS, 0))
|
||||
p_v = tl.advance(p_v, (0, BTS))
|
||||
o_k += BTS
|
||||
p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _parallel_rebased_bwd_dkv(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale,
|
||||
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
|
||||
DK: tl.constexpr, DV: tl.constexpr,
|
||||
):
|
||||
# compute dk dv
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
|
||||
(i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
|
||||
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
|
||||
b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(
|
||||
p_v, boundary_check=(0, 1))
|
||||
b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
|
||||
[BTL, BV], dtype=tl.float32)
|
||||
|
||||
for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
|
||||
p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS]
|
||||
b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
|
||||
b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \
|
||||
scale # [BTL, BTS]
|
||||
b_s2 = b_s * b_s
|
||||
b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
|
||||
b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
|
||||
if i_v == 0:
|
||||
b_ds += b_dz[None, :] * scale
|
||||
else:
|
||||
b_ds = b_ds
|
||||
b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype),
|
||||
tl.trans(b_q), allow_tf32=False)
|
||||
|
||||
tl.debug_barrier()
|
||||
o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
|
||||
for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
|
||||
p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
|
||||
b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
|
||||
# [BK, BQ]
|
||||
m_s = o_k[:, None] <= o_q[None, :]
|
||||
b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
|
||||
b_s2 = b_s * b_s
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
b_s2 = tl.where(m_s, b_s2, 0)
|
||||
|
||||
b_ds = tl.dot(b_v, b_do, allow_tf32=False)
|
||||
if i_v == 0:
|
||||
b_ds += b_dz[None, :]
|
||||
else:
|
||||
b_ds = b_ds
|
||||
b_ds = tl.where(m_s, b_ds, 0) * scale
|
||||
# [BK, BD]
|
||||
b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
|
||||
b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype),
|
||||
tl.trans(b_q), allow_tf32=False)
|
||||
o_q += BTS
|
||||
|
||||
p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,
|
||||
(T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,
|
||||
(T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def parallel_rebased_bwd_kernel(
|
||||
q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale,
|
||||
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
|
||||
DK: tl.constexpr, DV: tl.constexpr,
|
||||
):
|
||||
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
NV = tl.cdiv(DV, BV)
|
||||
i_k = i_kv // (NV)
|
||||
i_v = i_kv % (NV)
|
||||
i_h = i_bh % H
|
||||
_parallel_rebased_bwd_dq(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV
|
||||
)
|
||||
tl.debug_barrier()
|
||||
_parallel_rebased_bwd_dkv(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV
|
||||
)
|
||||
|
||||
|
||||
class ParallelBasedFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, q, k, v, scale):
|
||||
BTL, BTS = 128, 32
|
||||
assert BTL % BTS == 0
|
||||
# assert q.shape[-1] % 16 == 0
|
||||
BK = min(128, triton.next_power_of_2(k.shape[-1]))
|
||||
BV = min(128, triton.next_power_of_2(v.shape[-1]))
|
||||
BK, BV = max(BK, 16), max(BV, 16)
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
num_stages = 2
|
||||
num_warps = 4
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
NV = triton.cdiv(d_head_v, BV)
|
||||
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
|
||||
|
||||
assert NK == 1, "will encounter some synchronization issue if not."
|
||||
|
||||
o = torch.empty(NK, batch_size, n_heads, seq_len,
|
||||
d_head_v, device=q.device)
|
||||
z = torch.empty(NK, batch_size, n_heads, seq_len,
|
||||
device=q.device)
|
||||
parallel_rebased_fwd_kernel[grid](
|
||||
q, k, v, o, z,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
ctx.save_for_backward(q, k, v)
|
||||
ctx.scale = scale
|
||||
return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
@contiguous
|
||||
def backward(ctx, do, dz):
|
||||
q, k, v = ctx.saved_tensors
|
||||
scale = ctx.scale
|
||||
BTL, BTS = 64, 32
|
||||
assert BTL % BTS == 0
|
||||
BK = min(128, triton.next_power_of_2(k.shape[-1]))
|
||||
BV = min(128, triton.next_power_of_2(v.shape[-1]))
|
||||
BK, BV = max(BK, 16), max(BV, 16)
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
num_stages = 2
|
||||
num_warps = 4
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
NV = triton.cdiv(d_head_v, BV)
|
||||
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
|
||||
|
||||
assert NK == 1, "will encounter some synchronization issue if not"
|
||||
|
||||
dq = torch.empty(NV, batch_size, n_heads, seq_len,
|
||||
d_head_qk, dtype=q.dtype, device=q.device)
|
||||
dk = torch.empty(NV, batch_size, n_heads, seq_len,
|
||||
d_head_qk, dtype=q.dtype, device=q.device)
|
||||
dv = torch.empty(NK, batch_size, n_heads, seq_len,
|
||||
d_head_v, dtype=q.dtype, device=q.device)
|
||||
|
||||
parallel_rebased_bwd_kernel[grid](
|
||||
q, k, v, do, dz, dq, dk, dv,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
|
||||
|
||||
|
||||
triton_parallel_based = ParallelBasedFunction.apply
|
||||
|
||||
|
||||
def parallel_rebased(q, k, v, eps=1e-5, use_scale=True, use_normalize=True, return_both=False):
|
||||
assert q.shape[-1] <= 128, "only support feature dim up to 128"
|
||||
if use_scale:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
else:
|
||||
scale = 1
|
||||
o, z = triton_parallel_based(q, k, v, scale)
|
||||
if return_both:
|
||||
return o, z
|
||||
if use_normalize:
|
||||
o = o / (z[..., None] + eps)
|
||||
|
||||
return o.to(q.dtype)
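# Minimal usage sketch (not part of the original file; shapes below are
# illustrative assumptions): q/k/v are contiguous CUDA tensors of shape
# [batch, heads, seq_len, d_head] with the query/key head dim <= 128.
#
#   q = torch.randn(2, 4, 1024, 16, device='cuda')
#   k = torch.randn(2, 4, 1024, 16, device='cuda')
#   v = torch.randn(2, 4, 1024, 64, device='cuda')
#   o = parallel_rebased(q, k, v)                       # normalized output, shaped like v
#   o, z = parallel_rebased(q, k, v, return_both=True)  # unnormalized output and normalizer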
|
||||
13
finetune/lora/v6/fla/ops/retention/__init__.py
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .chunk import chunk_retention
|
||||
from .chunk_fuse import fused_chunk_retention
|
||||
from .parallel import parallel_retention
|
||||
from .recurrent_fuse import fused_recurrent_retention
|
||||
|
||||
__all__ = [
|
||||
'chunk_retention',
|
||||
'fused_chunk_retention',
|
||||
'parallel_retention',
|
||||
'fused_recurrent_retention'
|
||||
]
|
||||
364
finetune/lora/v6/fla/ops/retention/chunk.py
vendored
Normal file
@@ -0,0 +1,364 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_retention_fwd_kernel_h(
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
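# b_b is log2(gamma) for the per-head decay gamma = 1 - 2**(-5 - i_h),
# so tl.math.exp2(n * b_b) below evaluates gamma**n.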
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((BT - o_i - 1) * b_b)
|
||||
# [BK, BV]
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
for i_t in range(NT):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
if i_t == NT - 1 and (T % BT) != 0:
|
||||
d_b = tl.math.exp2((T % BT) * b_b)
|
||||
d_i = tl.math.exp2(((T % BT) - o_i - 1) * b_b)
|
||||
b_h = d_b * b_h + tl.dot(b_k, (b_v * d_i[:, None]).to(b_k.dtype), allow_tf32=False)
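# cross-chunk state update: b_h <- gamma**chunk_len * b_h
#   + sum_i gamma**(chunk_len - 1 - i) * k_i^T v_i over positions i of this chunk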
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_retention_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
o,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
d_i = tl.math.exp2((o_i + 1) * b_b)
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)
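# d_s is the causal decay matrix within the chunk: gamma**(i - j) for i >= j,
# zero above the diagonal.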
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_s = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
b_o += tl.dot((b_q * d_i[:, None]).to(b_q.dtype), b_h, allow_tf32=False)
|
||||
b_s += tl.dot(b_q, b_k, allow_tf32=False)
|
||||
|
||||
b_s *= d_s
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_retention_bwd_kernel_dh(
|
||||
q,
|
||||
do,
|
||||
dh,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b)
|
||||
# [BK, BV]
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for i_t in range(NT - 1, -1, -1):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BT, V]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh = d_b * b_dh + tl.dot(b_q, (b_do * d_i[:, None]).to(b_q.dtype), allow_tf32=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_retention_bwd_kernel_dqkv(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
do,
|
||||
dh,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
n_bh = tl.num_programs(2)
|
||||
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
d_q, d_k = tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)
|
||||
d_q = (d_q * scale).to(d_q.dtype)
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale
|
||||
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_s = tl.dot(b_k, b_q, allow_tf32=False) * tl.trans(d_s)
|
||||
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BV, BK]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh = tl.load(p_dh, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BT]
|
||||
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_do, b_h, allow_tf32=False)
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
|
||||
# [BT, BV]
|
||||
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * d_k[:, None] + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_ds = (b_ds * d_s).to(b_q.dtype)
|
||||
# [BT, BK]
|
||||
b_dq = b_dq * d_q[:, None] + tl.dot(b_ds, b_k, allow_tf32=False)
|
||||
b_dk = b_dk * d_k[:, None] + tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
|
||||
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class ChunkRetentionFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
@contiguous
|
||||
def forward(ctx, q, k, v, initial_state, output_final_state):
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT = 64
|
||||
BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
|
||||
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
scale = K ** -0.5
|
||||
|
||||
final_state = None
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
|
||||
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_retention_fwd_kernel_h[grid](
|
||||
k, v, h, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2),
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=output_final_state,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
grid = (NV, NT, B * H)
|
||||
o = torch.empty_like(v)
|
||||
chunk_retention_fwd_kernel_o[grid](
|
||||
q, k, v, h, o,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2),
|
||||
scale,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, h)
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
@contiguous
|
||||
def backward(ctx, do, d_ht=None):
|
||||
q, k, v, h = ctx.saved_tensors
|
||||
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT = 64
|
||||
BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
|
||||
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
scale = K ** -0.5
|
||||
|
||||
dh = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_retention_bwd_kernel_dh[grid](
|
||||
q, do, dh,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
dh.stride(1), dh.stride(2),
|
||||
scale,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
grid = (NK, NT, B * H)
|
||||
dq = torch.empty_like(q)
|
||||
dk = torch.empty_like(k)
|
||||
dv = v.new_empty(NK, *v.shape)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
chunk_retention_bwd_kernel_dqkv[grid](
|
||||
q, k, v, h, do, dh, dq, dk, dv,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
dh.stride(1), dh.stride(2),
|
||||
scale,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
dv = dv.sum(0)
|
||||
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None
|
||||
|
||||
|
||||
def chunk_retention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
o, final_state = ChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)
|
||||
return o, final_state
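# Minimal usage sketch (illustrative assumptions: contiguous CUDA tensors of
# shape [batch, heads, seq_len, d_head]):
#
#   q = torch.randn(2, 8, 2048, 64, device='cuda')
#   k = torch.randn(2, 8, 2048, 64, device='cuda')
#   v = torch.randn(2, 8, 2048, 64, device='cuda')
#   o, state = chunk_retention(q, k, v, output_final_state=True)
#   # `state` is [batch, heads, d_head_k, d_head_v] and can seed a later call
#   o2, _ = chunk_retention(q, k, v, initial_state=state)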
|
||||
334
finetune/lora/v6/fla/ops/retention/chunk_fuse.py
vendored
Normal file
@@ -0,0 +1,334 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from packaging import version
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
# on-the-fly computation without materializing hidden states in HBM
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_chunk_retention_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
CHECK: tl.constexpr
|
||||
):
|
||||
# indices
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
# decay rate given the head index
|
||||
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
|
||||
|
||||
# d_b: overall decay for the entire chunk
|
||||
# d_o: cumulative decay from the start of the chunk
|
||||
# d_h: cumulative decay from the end of the chunk
|
||||
d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)
|
||||
|
||||
# [BT, BT]
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)
|
||||
# [BK, BV]
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
# make block pointers
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
NT = tl.cdiv(T, BT)
|
||||
for i in range(0, NT):
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_k.dtype)
|
||||
|
||||
# [BT, BT]
|
||||
b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s
|
||||
# [BT, BV]
|
||||
b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
|
||||
if CHECK and i == 0:
|
||||
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]
|
||||
b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)
|
||||
else:
|
||||
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None]
|
||||
if i == NT - 1 and (T % BT) != 0:
|
||||
d_b = tl.math.exp2((T % BT) * b_b)
|
||||
d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b)
|
||||
b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
p_q = tl.advance(p_q, (BT, 0))
|
||||
p_k = tl.advance(p_k, (0, BT))
|
||||
p_v = tl.advance(p_v, (BT, 0))
|
||||
p_o = tl.advance(p_o, (BT, 0))
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
|
||||
@triton.jit
|
||||
def fused_chunk_retention_bwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
# NV: number of splits in the V dimension. NK: number of splits in the K dimension
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
do, # gradient of output [B, H, L, D_head_V]
|
||||
dq, # gradient of query [NV, B, H, L, D_head_K]
|
||||
dk, # gradient of key [NV, B, H, L, D_head_K]
|
||||
dv, # gradient of value [NK, B, H, L, D_head_V]
|
||||
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
CHECK: tl.constexpr
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
|
||||
d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b)
|
||||
d_b = tl.math.exp2(BT * b_b)
|
||||
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale
|
||||
# [BV, BK]
|
||||
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
for i in range(0, tl.cdiv(T, BT)):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
|
||||
|
||||
# [BT, DK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [DV, BT]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, DV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
b_dd = (b_do * d_q[:, None]).to(b_do.dtype)
|
||||
|
||||
# [BT, BT]
|
||||
b_ds = tl.dot(b_do, b_v, allow_tf32=False)
|
||||
b_ds = (b_ds * d_s).to(b_k.dtype)
|
||||
# [BT, DK]
|
||||
b_dq = tl.dot(b_ds, b_k, allow_tf32=False)
|
||||
# [DV, DK]
|
||||
if CHECK and i == 0:
|
||||
b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)
|
||||
b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)
|
||||
else:
|
||||
b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False)
|
||||
b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False)
|
||||
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
# sync threads
|
||||
b_h = None
|
||||
tl.debug_barrier()
|
||||
d_s = tl.trans(d_s)
|
||||
# [BK, BV]
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for i in range(1, tl.cdiv(T, BT) + 1):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
|
||||
# [DK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BT, DK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, DV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
b_dd = (b_do * d_q[:, None]).to(b_do.dtype)
|
||||
|
||||
# [BT, BT]
|
||||
b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
|
||||
b_ds = (b_ds * d_s).to(b_k.dtype)
|
||||
|
||||
# [BT, BT]
|
||||
b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s
|
||||
# [BT, DK]
|
||||
b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
|
||||
# [BT, DV]
|
||||
b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
|
||||
if CHECK and i == 1:
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
|
||||
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
|
||||
b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)
|
||||
else:
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
|
||||
b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
|
||||
b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)
|
||||
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class FusedChunkRetentionFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, q, k, v, initial_state, output_final_state):
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
|
||||
scale = d_head_qk ** -0.5
|
||||
BT = 64
|
||||
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4
|
||||
|
||||
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
|
||||
else:
|
||||
final_state = None
|
||||
# the bug still exists even for Triton 2.2 on H100 GPUs
|
||||
# so we always enable initial checks
|
||||
CHECK = True
|
||||
if version.parse(triton.__version__) < version.parse('2.2.0'):
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Triton<2.2.0 detected for running this kernel, "
|
||||
"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
|
||||
"that lead to significant precision loss. "
|
||||
"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
|
||||
"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
|
||||
)
|
||||
CHECK = True
|
||||
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
fused_chunk_retention_fwd_kernel[grid](
|
||||
q, k, v, o, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=output_final_state,
|
||||
CHECK=CHECK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
o = o.sum(0)
|
||||
ctx.save_for_backward(q, k, v, initial_state)
|
||||
ctx.CHECK = CHECK
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
@contiguous
|
||||
def backward(ctx, do, d_final_state=None):
|
||||
q, k, v, initial_state = ctx.saved_tensors
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
scale = d_head_qk ** -0.5
|
||||
|
||||
BT = 64
|
||||
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4
|
||||
|
||||
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
|
||||
fused_chunk_retention_bwd_kernel[grid](
|
||||
q, k, v, do, dq, dk, dv, initial_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
CHECK=ctx.CHECK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
dq = dq.sum(0)
|
||||
dk = dk.sum(0)
|
||||
dv = dv.sum(0)
|
||||
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None
|
||||
|
||||
|
||||
def fused_chunk_retention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)
|
||||
return o, final_state
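# Sketch of carrying state across segments (assumption: splitting a sequence
# in half and chaining calls via initial_state/output_final_state is intended
# to match a single full-length call; exact parity is not verified here):
#
#   q1, q2 = q.chunk(2, dim=2)
#   k1, k2 = k.chunk(2, dim=2)
#   v1, v2 = v.chunk(2, dim=2)
#   o1, state = fused_chunk_retention(q1, k1, v1, output_final_state=True)
#   o2, _ = fused_chunk_retention(q2, k2, v2, initial_state=state)
#   o = torch.cat([o1, o2], dim=2)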
|
||||
15
finetune/lora/v6/fla/ops/retention/naive.py
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def naive_retention(q, k, v):
|
||||
orig_type = q.dtype
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
_, n_heads, seq_len, d_head = q.shape
|
||||
s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log2()
|
||||
n = q.new_tensor(range(seq_len), dtype=torch.float)
|
||||
n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n)
|
||||
s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype))
|
||||
o = torch.einsum('bhqk,bhkd->bhqd', s, v)
|
||||
return o.to(orig_type)
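# Small correctness-check sketch against the fused kernels (assumes
# fla.ops.retention.chunk_retention is importable and a CUDA device is
# available; tolerances are illustrative):
#
#   from fla.ops.retention import chunk_retention
#   q = torch.randn(1, 4, 256, 32, device='cuda')
#   k = torch.randn(1, 4, 256, 32, device='cuda')
#   v = torch.randn(1, 4, 256, 32, device='cuda')
#   ref = naive_retention(q, k, v)
#   out, _ = chunk_retention(q, k, v)
#   torch.testing.assert_close(ref, out, rtol=1e-3, atol=1e-3)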
|
||||
339
finetune/lora/v6/fla/ops/retention/parallel.py
vendored
Normal file
@@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.jit
|
||||
def parallel_retention_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q
|
||||
BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
):
|
||||
# i_c: chunk index. used for sequence parallelism
|
||||
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
NV = tl.cdiv(DV, BV)
|
||||
i_k = i_kv // (NV)
|
||||
i_v = i_kv % (NV)
|
||||
i_h = i_bh % H
|
||||
# decay rate given the head index
|
||||
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
|
||||
# cumulative decay from the end of the chunk
|
||||
o_k = tl.arange(0, BTS)
|
||||
d_h = tl.math.exp2((BTS - o_k) * b_b)
|
||||
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
|
||||
(s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))
|
||||
|
||||
# [BQ, BD] block Q, in the shared memory throughout the whole kernel
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
b_o = tl.zeros([BTL, BV], dtype=tl.float32)
|
||||
|
||||
# Q block and K block have no overlap
|
||||
# no need for mask, thereby saving flops
|
||||
for _ in range(0, i_c * BTL, BTS):
|
||||
# [BK, BTS]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BTS, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
b_s = tl.dot(b_q, (b_k), allow_tf32=False) * d_h[None, :]
|
||||
# [BQ, BD]
|
||||
b_o = b_o * tl.math.exp2(b_b * BTS)
|
||||
b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
|
||||
p_k = tl.advance(p_k, (0, BTS))
|
||||
p_v = tl.advance(p_v, (BTS, 0))
|
||||
|
||||
# # rescale interchunk output
|
||||
tl.debug_barrier()
|
||||
o_q = tl.arange(0, BTL)
|
||||
d_q = tl.math.exp2(tl.arange(0, BTL) * b_b)
|
||||
b_o *= d_q[:, None]
|
||||
# # sync threads, easy for compiler to optimize
|
||||
# tl.debug_barrier()
|
||||
|
||||
o_k = tl.arange(0, BTS)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),
|
||||
(s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
|
||||
# Q block and K block have overlap. masks required
|
||||
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
|
||||
# [BK, BTS]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BTS, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
m_s = o_q[:, None] >= o_k[None, :]
|
||||
d_s = tl.where(m_s, tl.math.exp2(
|
||||
(o_q[:, None] - o_k[None, :]) * b_b), 0)
|
||||
b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s
|
||||
# [BTL, BV]
|
||||
b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
|
||||
|
||||
p_k = tl.advance(p_k, (0, BTS))
|
||||
p_v = tl.advance(p_v, (BTS, 0))
|
||||
o_k += BTS
|
||||
|
||||
p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _parallel_retention_bwd_dq(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale,
|
||||
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
|
||||
DK: tl.constexpr, DV: tl.constexpr,
|
||||
):
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
|
||||
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
|
||||
(s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))
|
||||
# decay rate given the head index
|
||||
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
|
||||
# overall decay rate for an entire block
|
||||
d_b = tl.math.exp2(b_b * BTS)
|
||||
# cumulative decay from the end of the chunk
|
||||
d_h = tl.math.exp2((BTS - tl.arange(0, BTS)) * b_b)
|
||||
for _ in range(0, i_c * BTL, BTS):
|
||||
# [BTS, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BV, BTS]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_h[None, :]
|
||||
# [BQ, BD]
|
||||
b_dq *= d_b
|
||||
b_dq += tl.dot(b_ds.to(b_v.dtype), b_k, allow_tf32=False)
|
||||
p_k = tl.advance(p_k, (BTS, 0))
|
||||
p_v = tl.advance(p_v, (0, BTS))
|
||||
b_dq *= tl.math.exp2(tl.arange(0, BTL) * b_b)[:, None] * scale
|
||||
o_q = tl.arange(0, BTL)
|
||||
o_k = tl.arange(0, BTS)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),
|
||||
(s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
|
||||
# Q block and K block have overlap. masks required
|
||||
for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
|
||||
# [BTS, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BV, BTS]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BTL, BTS]
|
||||
m_s = o_q[:, None] >= o_k[None, :]
|
||||
d_s = tl.where(m_s, tl.math.exp2(
|
||||
(o_q[:, None] - o_k[None, :]) * b_b), 0)
|
||||
b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_s * scale
|
||||
# [BTL, BK]
|
||||
b_dq += tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
|
||||
p_k = tl.advance(p_k, (BTS, 0))
|
||||
p_v = tl.advance(p_v, (0, BTS))
|
||||
o_k += BTS
|
||||
p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _parallel_retention_bwd_dkv(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale,
|
||||
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
|
||||
DK: tl.constexpr, DV: tl.constexpr,
|
||||
):
|
||||
# no overlap. no need for mask.
|
||||
b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0))
|
||||
# overall decay rate for an entire block
|
||||
d_b = tl.math.exp2(b_b * BTS)
|
||||
# compute dk dv
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
|
||||
(i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),
|
||||
(i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
|
||||
b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(
|
||||
p_v, boundary_check=(0, 1))
|
||||
b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
|
||||
[BTL, BV], dtype=tl.float32)
|
||||
d_h = tl.math.exp2((BTL - tl.arange(0, BTL)) * b_b)
|
||||
b_kd = (b_k * d_h[:, None]).to(b_k.dtype)
|
||||
d_q = tl.math.exp2(tl.arange(0, BTS) * b_b)
|
||||
for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1)) # [BV, BTS]
|
||||
b_do = (b_do * d_q[None, :]).to(b_do.dtype)
|
||||
|
||||
b_dv *= d_b
|
||||
b_s = tl.dot(b_kd.to(b_q.dtype), b_q, allow_tf32=False) # [BTL, BTS]
|
||||
b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
|
||||
|
||||
b_dk *= d_b
|
||||
b_ds = tl.dot(b_v, b_do, allow_tf32=False)
|
||||
b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
|
||||
b_dk *= d_h[:, None] * scale
|
||||
b_dv *= scale
|
||||
tl.debug_barrier()
|
||||
o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
|
||||
for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BK, BQ]
|
||||
m_s = o_k[:, None] <= o_q[None, :]
|
||||
d_s = tl.where(m_s, tl.math.exp2(
|
||||
(-o_k[:, None] + o_q[None, :]) * b_b.to(tl.float32)), 0) * scale
|
||||
b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s
|
||||
b_ds = tl.dot(b_v, b_do, allow_tf32=False) * d_s
|
||||
# [BK, BD]
|
||||
b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
|
||||
b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
|
||||
o_q += BTS
|
||||
p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,
|
||||
(T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,
|
||||
(T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def parallel_retention_bwd_kernel(
|
||||
q, k, v, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale,
|
||||
BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
|
||||
DK: tl.constexpr, DV: tl.constexpr,
|
||||
):
|
||||
i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
NV = tl.cdiv(DV, BV)
|
||||
i_k = i_kv // (NV)
|
||||
i_v = i_kv % (NV)
|
||||
i_h = i_bh % H
|
||||
_parallel_retention_bwd_dq(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV
|
||||
)
|
||||
tl.debug_barrier()
|
||||
_parallel_retention_bwd_dkv(
|
||||
i_bh, i_c, i_k, i_v, i_h,
|
||||
q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,
|
||||
s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV
|
||||
)
|
||||
|
||||
|
||||
class ParallelRetentionFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, q, k, v):
|
||||
BTL, BTS = 128, 32
|
||||
assert BTL % BTS == 0
|
||||
BK = min(128, triton.next_power_of_2(k.shape[-1]))
|
||||
BV = min(128, triton.next_power_of_2(v.shape[-1]))
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
num_stages = 3 if d_head_qk <= 64 else 2
|
||||
num_warps = 4
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
NV = triton.cdiv(d_head_v, BV)
|
||||
|
||||
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
|
||||
scale = d_head_qk ** -0.5
|
||||
o = torch.empty(NK, batch_size, n_heads, seq_len,
|
||||
d_head_v, dtype=q.dtype, device=q.device)
|
||||
parallel_retention_fwd_kernel[grid](
|
||||
q, k, v, o,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
ctx.save_for_backward(q, k, v)
|
||||
return o.sum(0).to(q.dtype)
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_bwd
|
||||
def backward(ctx, do):
|
||||
q, k, v = ctx.saved_tensors
|
||||
BTL, BTS = 64, 32
|
||||
assert BTL % BTS == 0
|
||||
BK = min(128, triton.next_power_of_2(k.shape[-1]))
|
||||
BV = min(128, triton.next_power_of_2(v.shape[-1]))
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
num_stages = 3 if d_head_qk <= 64 else 2
|
||||
num_warps = 4
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
NV = triton.cdiv(d_head_v, BV)
|
||||
grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)
|
||||
scale = d_head_qk ** -0.5
|
||||
|
||||
dq = torch.empty(NV, batch_size, n_heads, seq_len,
|
||||
d_head_qk, dtype=q.dtype, device=q.device)
|
||||
dk = torch.empty(NV, batch_size, n_heads, seq_len,
|
||||
d_head_qk, dtype=q.dtype, device=q.device)
|
||||
dv = torch.empty(NK, batch_size, n_heads, seq_len,
|
||||
d_head_v, dtype=q.dtype, device=q.device)
|
||||
|
||||
parallel_retention_bwd_kernel[grid](
|
||||
q, k, v, do, dq, dk, dv,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype)
|
||||
|
||||
|
||||
parallel_retention = ParallelRetentionFunction.apply
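# Minimal usage sketch (illustrative assumptions: CUDA tensors shaped
# [batch, heads, seq_len, d_head], head dims up to 128; the parallel form
# returns only the output, no recurrent state):
#
#   q = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
#   k = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
#   v = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
#   o = parallel_retention(q, k, v)   # same shape as v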
|
||||
281
finetune/lora/v6/fla/ops/retention/recurrent_fuse.py
vendored
Normal file
@@ -0,0 +1,281 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
# on-the-fly computation without materializing hidden states in HBM
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_retention_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
initial_state,
|
||||
final_state, # final hidden state [B, H, D_head_K, D_head_V]
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
|
||||
):
|
||||
# indices
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
|
||||
# decay rate given the head index
|
||||
b_b = (1 - tl.math.pow(2, -5 - i_h * 1.0))
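# note: unlike the chunk kernels, b_b here is the decay gamma = 1 - 2**(-5 - i_h)
# itself (not log2(gamma)), since the recurrence multiplies by gamma once per step.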
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||||
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||||
|
||||
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
|
||||
mask_kv = mask_bk[None, :] & mask_bv[:, None]
|
||||
|
||||
h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_init_s = initial_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[None, :]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
for _ in range(0, T):
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
|
||||
h = b_b * h + _k[None, :] * _v[:, None]
|
||||
_o = h * _q[None, :]
|
||||
_o = tl.sum(_o, axis=1)
|
||||
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
|
||||
|
||||
p_q += DK
|
||||
p_k += DK
|
||||
p_o += DV
|
||||
p_v += DV
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_final_s = final_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[None, :]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
|
||||
|
||||
|
||||
# Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
|
||||
@triton.jit
|
||||
def fused_recurrent_retention_bwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
# NV: number of splits along the V dimension. NK: number of splits along the K dimension
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
|
||||
do, # gradient of output [B, H, L, D_head_V]
|
||||
dq, # gradient of query [NV, B, H, L, D_head_K]
|
||||
dk, # gradient of key [NV, B, H, L, D_head_K]
|
||||
dv, # gradient of value [NK, B, H, L, D_head_V]
|
||||
|
||||
# initial hidden state [B, H, D_head_K, D_head_V]
|
||||
initial_state,
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
|
||||
b_b = 1 - tl.math.pow(2, -5 - i_h * 1.0)
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||||
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
|
||||
|
||||
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)
|
||||
mask_bk = i_k * BK + tl.arange(0, BK) < DK
|
||||
mask_bv = i_v * BV + tl.arange(0, BV) < DV
|
||||
|
||||
h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
mask_kv = mask_bk[:, None] & mask_bv[None, :]
|
||||
p_init_s = initial_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[:, None]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[None, :])
|
||||
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
for i in range(0, T):
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
|
||||
h = b_b * h + _k[:, None] * _v[None, :]
|
||||
_d_q = h * _do[None, :]
|
||||
d_q = tl.sum(_d_q, axis=1) * scale
|
||||
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
|
||||
|
||||
p_k += DK
|
||||
p_do += DV
|
||||
p_v += DV
|
||||
p_dq += DK
|
||||
|
||||
# sync threads
|
||||
tl.debug_barrier()
|
||||
|
||||
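# Second pass: walk the sequence in reverse. d_h accumulates
# sum_{t' >= t} b_b^(t' - t) * (scale * q_{t'}) * do_{t'}^T, from which
# dk_t = d_h @ v_t and dv_t = d_h^T @ k_t are read off before the decay is applied.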
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK
|
||||
p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV
|
||||
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
|
||||
BK + tl.arange(0, BK) + (T - 1) * DK
|
||||
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
|
||||
BV + tl.arange(0, BV) + (T - 1) * DV
|
||||
d_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
for _ in range(T):
|
||||
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
d_h += _q[:, None] * _do[None, :]
|
||||
d_k = tl.sum(d_h * _v[None, :], axis=1)
|
||||
d_v = tl.sum(d_h * _k[:, None], axis=0)
|
||||
|
||||
d_h *= b_b
|
||||
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
|
||||
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
|
||||
|
||||
p_do -= DV
|
||||
p_q -= DK
|
||||
p_k -= DK
|
||||
p_v -= DV
|
||||
p_dk -= DK
|
||||
p_dv -= DV
|
||||
|
||||
|
||||
class FusedRecurrentRetentionFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(ctx, q, k, v, initial_state=None, output_final_state=False):
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
|
||||
scale = d_head_qk ** -0.5
|
||||
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
|
||||
else:
|
||||
final_state = None
|
||||
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
fused_recurrent_retention_fwd_kernel[grid](
|
||||
q, k, v, o, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=final_state is not None
|
||||
)
|
||||
|
||||
o = o.sum(0)
|
||||
ctx.save_for_backward(q, k, v, initial_state)
|
||||
return o, final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, do, d_final_state=None):
|
||||
q, k, v, initial_state = ctx.saved_tensors
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
scale = d_head_qk ** -0.5
|
||||
|
||||
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
|
||||
fused_recurrent_retention_bwd_kernel[grid](
|
||||
q, k, v, do, dq, dk, dv, initial_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state is not None
|
||||
)
|
||||
dq = dq.sum(0)
|
||||
dk = dk.sum(0)
|
||||
dv = dv.sum(0)
|
||||
return dq, dk, dv, None, None
|
||||
|
||||
|
||||
# fused_recurrent_retention = FusedRecurrentRetentionFunction.apply
|
||||
|
||||
def fused_recurrent_retention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
o, final_state = FusedRecurrentRetentionFunction.apply(q, k, v, initial_state, output_final_state)
|
||||
return o, final_state
|
||||
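As a slow reference for what the fused kernel above computes, the same recurrence can be written directly in PyTorch. This is a testing sketch only (float32 state, per-step Python loop), not part of the module; the [batch, heads, seq_len, head_dim] layout mirrors the kernel strides.
import torch

def naive_recurrent_retention(q, k, v, initial_state=None):
    # h_t = decay_h * h_{t-1} + k_t^T v_t,  o_t = (scale * q_t) @ h_t,  decay_h = 1 - 2^(-5 - h)
    B, H, T, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5
    decay = 1 - torch.pow(2.0, -5 - torch.arange(H, dtype=torch.float, device=q.device))
    h = q.new_zeros(B, H, K, V, dtype=torch.float)
    if initial_state is not None:
        h = h + initial_state.float()
    o = torch.zeros(B, H, T, V, dtype=torch.float, device=q.device)
    for t in range(T):
        h = decay[None, :, None, None] * h + k[:, :, t, :, None].float() * v[:, :, t, None, :].float()
        o[:, :, t] = torch.einsum('bhk,bhkv->bhv', q[:, :, t].float() * scale, h)
    return o.to(v.dtype), h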
252
finetune/lora/v6/fla/ops/rotary.py
vendored
Normal file
@@ -0,0 +1,252 @@
|
||||
# Copyright (c) 2023, Tri Dao. https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BLOCK_M": 2}),
|
||||
# triton.Config({"BLOCK_M": 4}),
|
||||
# triton.Config({"BLOCK_M": 8}),
|
||||
# triton.Config({"BLOCK_M": 16}),
|
||||
# ],
|
||||
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
|
||||
# )
|
||||
@triton.jit
|
||||
def rotary_kernel(
|
||||
OUT, # Pointers to matrices
|
||||
X,
|
||||
COS,
|
||||
SIN,
|
||||
CU_SEQLENS,
|
||||
SEQLEN_OFFSETS, # this could be int or a pointer
|
||||
# Matrix dimensions
|
||||
seqlen,
|
||||
nheads,
|
||||
rotary_dim,
|
||||
seqlen_ro,
|
||||
CACHE_KEY_SEQLEN,
|
||||
# strides
|
||||
stride_out_batch,
|
||||
stride_out_seqlen,
|
||||
stride_out_nheads,
|
||||
stride_out_headdim,
|
||||
stride_x_batch,
|
||||
stride_x_seqlen,
|
||||
stride_x_nheads,
|
||||
stride_x_headdim,
|
||||
# Meta-parameters
|
||||
BLOCK_K: tl.constexpr,
|
||||
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
INTERLEAVED: tl.constexpr,
|
||||
CONJUGATE: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_batch = tl.program_id(axis=1)
|
||||
pid_head = tl.program_id(axis=2)
|
||||
rotary_dim_half = rotary_dim // 2
|
||||
|
||||
if not IS_VARLEN:
|
||||
X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
|
||||
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
|
||||
else:
|
||||
start_idx = tl.load(CU_SEQLENS + pid_batch)
|
||||
seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
|
||||
X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
|
||||
OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
|
||||
|
||||
if pid_m * BLOCK_M >= seqlen:
|
||||
return
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
if not IS_SEQLEN_OFFSETS_TENSOR:
|
||||
rm_cs = rm + SEQLEN_OFFSETS
|
||||
else:
|
||||
rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
|
||||
rk = tl.arange(0, BLOCK_K)
|
||||
rk_half = tl.arange(0, BLOCK_K // 2)
|
||||
|
||||
if not INTERLEAVED:
|
||||
# Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
|
||||
X = X + (rm[:, None] * stride_x_seqlen +
|
||||
rk_half[None, :] * stride_x_headdim)
|
||||
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
|
||||
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
|
||||
cos = tl.load(
|
||||
COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
|
||||
).to(tl.float32)
|
||||
sin = tl.load(
|
||||
SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
|
||||
).to(tl.float32)
|
||||
x0 = tl.load(
|
||||
X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
|
||||
).to(tl.float32)
|
||||
x1 = tl.load(
|
||||
X + rotary_dim_half * stride_x_headdim,
|
||||
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
if CONJUGATE:
|
||||
sin = -sin
|
||||
o0 = x0 * cos - x1 * sin
|
||||
o1 = x0 * sin + x1 * cos
|
||||
# write back result
|
||||
OUT = OUT + (rm[:, None] * stride_out_seqlen +
|
||||
rk_half[None, :] * stride_out_headdim)
|
||||
tl.store(OUT, o0, mask=(rm[:, None] < seqlen)
|
||||
& (rk_half[None, :] < rotary_dim_half))
|
||||
tl.store(
|
||||
OUT + rotary_dim_half * stride_out_headdim,
|
||||
o1,
|
||||
mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
|
||||
)
|
||||
else:
|
||||
# We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
|
||||
# Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
|
||||
# Loading x0 will be fast but x1 will be slow.
|
||||
# Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
|
||||
# Then we do the calculation and use tl.where to pick out the right outputs for the even
|
||||
# and for the odd indices.
|
||||
rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
|
||||
rk_repeat = tl.arange(0, BLOCK_K) // 2
|
||||
X0 = X + (rm[:, None] * stride_x_seqlen +
|
||||
rk[None, :] * stride_x_headdim)
|
||||
X1 = X + (rm[:, None] * stride_x_seqlen +
|
||||
rk_swap[None, :] * stride_x_headdim)
|
||||
COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
|
||||
SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
|
||||
cos = tl.load(
|
||||
COS,
|
||||
mask=(rm_cs[:, None] < seqlen_ro) & (
|
||||
rk_repeat[None, :] < rotary_dim_half),
|
||||
other=1.0,
|
||||
).to(tl.float32)
|
||||
sin = tl.load(
|
||||
SIN,
|
||||
mask=(rm_cs[:, None] < seqlen_ro) & (
|
||||
rk_repeat[None, :] < rotary_dim_half),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
x1 = tl.load(
|
||||
X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
|
||||
).to(tl.float32)
|
||||
if CONJUGATE:
|
||||
sin = -sin
|
||||
x0_cos = x0 * cos
|
||||
x1_sin = x1 * sin
|
||||
out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
|
||||
OUT = OUT + (rm[:, None] * stride_out_seqlen +
|
||||
rk[None, :] * stride_out_headdim)
|
||||
tl.store(OUT, out, mask=(rm[:, None] < seqlen)
|
||||
& (rk[None, :] < rotary_dim))
|
||||
|
||||
|
||||
def apply_rotary(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
seqlen_offsets: Union[int, torch.Tensor] = 0,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
interleaved=False,
|
||||
inplace=False,
|
||||
conjugate=False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Arguments:
|
||||
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
|
||||
else (total_seqlen, nheads, headdim).
|
||||
cos: (seqlen_ro, rotary_dim / 2)
|
||||
sin: (seqlen_ro, rotary_dim / 2)
|
||||
seqlen_offsets: integer or integer tensor of size (batch,)
|
||||
cu_seqlens: (batch + 1,) or None
|
||||
max_seqlen: int
|
||||
Returns:
|
||||
y: (batch, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim)
|
||||
"""
|
||||
is_varlen = cu_seqlens is not None
|
||||
if not is_varlen:
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
else:
|
||||
assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
|
||||
total_seqlen, nheads, headdim = x.shape
|
||||
batch_p_1 = cu_seqlens.shape[0]
|
||||
batch = batch_p_1 - 1
|
||||
seqlen = max_seqlen
|
||||
seqlen_ro, rotary_dim = cos.shape
|
||||
assert sin.shape == cos.shape
|
||||
rotary_dim *= 2
|
||||
assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
|
||||
assert headdim <= 256, "Only support headdim <= 256"
|
||||
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
|
||||
|
||||
assert (
|
||||
cos.dtype == sin.dtype
|
||||
), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
|
||||
assert (
|
||||
x.dtype == cos.dtype
|
||||
), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
|
||||
|
||||
cos, sin = cos.contiguous(), sin.contiguous()
|
||||
if isinstance(seqlen_offsets, torch.Tensor):
|
||||
assert seqlen_offsets.shape == (batch,)
|
||||
assert seqlen_offsets.dtype in [torch.int32, torch.int64]
|
||||
seqlen_offsets = seqlen_offsets.contiguous()
|
||||
else:
|
||||
assert seqlen_offsets + seqlen <= seqlen_ro
|
||||
|
||||
output = torch.empty_like(x) if not inplace else x
|
||||
if rotary_dim < headdim and not inplace:
|
||||
output[..., rotary_dim:].copy_(x[..., rotary_dim:])
|
||||
|
||||
BLOCK_K = (
|
||||
32
|
||||
if rotary_dim <= 32
|
||||
else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
|
||||
)
|
||||
def grid(META): return (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
|
||||
BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
|
||||
|
||||
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
||||
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
||||
with torch.cuda.device(x.device.index):
|
||||
rotary_kernel[grid](
|
||||
output, # data ptrs
|
||||
x,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
seqlen_offsets,
|
||||
seqlen, # shapes
|
||||
nheads,
|
||||
rotary_dim,
|
||||
seqlen_ro,
|
||||
# key for triton cache (limit number of compilations)
|
||||
seqlen // 128,
|
||||
# batch_strides if not varlen else 0
|
||||
output.stride(0) if not is_varlen else 0,
|
||||
output.stride(-3), # seqlen_stride or total_seqlen_stride
|
||||
output.stride(-2), # nheads_stride
|
||||
output.stride(-1), # headdim_stride
|
||||
# batch_strides if not varlen else 0
|
||||
x.stride(0) if not is_varlen else 0,
|
||||
x.stride(-3), # seqlen stride or total_seqlen_stride
|
||||
x.stride(-2), # nheads stride
|
||||
x.stride(-1), # headdim stride
|
||||
BLOCK_K,
|
||||
isinstance(seqlen_offsets, torch.Tensor),
|
||||
is_varlen,
|
||||
interleaved,
|
||||
conjugate,
|
||||
BLOCK_M,
|
||||
)
|
||||
return output
|
||||
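A short usage sketch for the fixed-length (cu_seqlens is None) path of apply_rotary defined above. The 10000-base frequency table is the usual RoPE convention and is an assumption of this sketch, not something this file defines.
import torch

batch, seqlen, nheads, headdim, rotary_dim = 2, 128, 8, 64, 32
x = torch.randn(batch, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, device='cuda').float() / rotary_dim))
freqs = torch.outer(torch.arange(seqlen, device='cuda').float(), inv_freq)  # (seqlen, rotary_dim // 2)
cos, sin = freqs.cos().half(), freqs.sin().half()  # must match x.dtype
y = apply_rotary(x, cos, sin)  # rotates x[..., :rotary_dim]; the remaining dims are copied through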
7
finetune/lora/v6/fla/ops/rwkv4/__init__.py
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .recurrent_fuse import fused_recurrent_rwkv4
|
||||
|
||||
__all__ = [
|
||||
'fused_recurrent_rwkv4'
|
||||
]
|
||||
484
finetune/lora/v6/fla/ops/rwkv4/recurrent_fuse.py
vendored
Normal file
@@ -0,0 +1,484 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# adapted from https://github.com/codekansas/rwkv
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch import Tensor
|
||||
from torch.autograd.function import Function, FunctionCtx, once_differentiable
|
||||
|
||||
|
||||
def get_block_size_c(chans: int) -> int:
|
||||
if chans < 32:
|
||||
return 32
|
||||
if chans < 64:
|
||||
return 64
|
||||
return 128
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_rwkv4_forward_kernel(
|
||||
# W
|
||||
w_ptr,
|
||||
w_s_c,
|
||||
# U
|
||||
u_ptr,
|
||||
u_s_c,
|
||||
# K
|
||||
k_ptr,
|
||||
k_s_b,
|
||||
k_s_t,
|
||||
k_s_c,
|
||||
# V
|
||||
v_ptr,
|
||||
v_s_b,
|
||||
v_s_t,
|
||||
v_s_c,
|
||||
# State
|
||||
state_ptr,
|
||||
state_s_b,
|
||||
state_s_abe,
|
||||
state_s_c,
|
||||
# WKV
|
||||
wkv_ptr,
|
||||
wkv_s_b,
|
||||
wkv_s_t,
|
||||
wkv_s_c,
|
||||
# Output state
|
||||
state_out_ptr,
|
||||
state_out_s_b,
|
||||
state_out_s_abe,
|
||||
state_out_s_t,
|
||||
state_out_s_c,
|
||||
# Params
|
||||
chans,
|
||||
tsz,
|
||||
BLOCK_SIZE_C: tl.constexpr,
|
||||
):
|
||||
# Parallelize over the batch dimension.
|
||||
b_idx = tl.program_id(0)
|
||||
c_idx = tl.program_id(1)
|
||||
|
||||
cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
|
||||
cmask = cs < chans
|
||||
|
||||
# Pointers to the batch (and possibly channel) for the input tensors.
|
||||
k_ptr = k_ptr + b_idx * k_s_b
|
||||
v_ptr = v_ptr + b_idx * v_s_b
|
||||
alpha_ptr = state_ptr + b_idx * state_s_b
|
||||
beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
|
||||
eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe
|
||||
|
||||
# Pointers to the batch (and possibly channel) for the output tensors.
|
||||
wkv_ptr = wkv_ptr + b_idx * wkv_s_b
|
||||
alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b
|
||||
beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe
|
||||
eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe
|
||||
|
||||
# Loads parameters.
|
||||
alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
|
||||
beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
|
||||
eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
|
||||
w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)
|
||||
u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)
|
||||
|
||||
for t in range(tsz):
|
||||
kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)
|
||||
vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)
|
||||
|
||||
ukt = u + kt
|
||||
tau = tl.maximum(ukt, eps)
|
||||
e1a = tl.exp(eps - tau)
|
||||
e2a = tl.exp(ukt - tau)
|
||||
wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)
|
||||
tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)
|
||||
|
||||
w_eps = w + eps
|
||||
eps = tl.maximum(w_eps, kt)
|
||||
e1b = tl.exp(w_eps - eps)
|
||||
e2b = tl.exp(kt - eps)
|
||||
alpha = e1b * alpha + e2b * vt
|
||||
beta = e1b * beta + e2b
|
||||
tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)
|
||||
tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)
|
||||
tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)
|
||||
|
||||
|
||||
def fused_recurrent_rwkv4_forward(
|
||||
w: Tensor,
|
||||
u: Tensor,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
state: Tensor,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
(bsz, tsz, chans) = k.shape
|
||||
|
||||
# New tensors to output.
|
||||
wkvs = k.new_empty(bsz, tsz, chans)
|
||||
state_out = k.new_empty(bsz, 3, tsz, chans)
|
||||
|
||||
# Constants.
|
||||
block_size_c = get_block_size_c(chans)
|
||||
|
||||
def grid(meta: dict[str, Any]) -> tuple[int, ...]:
|
||||
return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))
|
||||
|
||||
fused_recurrent_rwkv4_forward_kernel[grid](
|
||||
# W
|
||||
w,
|
||||
w.stride(0),
|
||||
# U
|
||||
u,
|
||||
u.stride(0),
|
||||
# K
|
||||
k,
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
# V
|
||||
v,
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
# State
|
||||
state,
|
||||
state.stride(0),
|
||||
state.stride(1),
|
||||
state.stride(3),
|
||||
# WKV
|
||||
wkvs,
|
||||
wkvs.stride(0),
|
||||
wkvs.stride(1),
|
||||
wkvs.stride(2),
|
||||
# Output state
|
||||
state_out,
|
||||
state_out.stride(0),
|
||||
state_out.stride(1),
|
||||
state_out.stride(2),
|
||||
state_out.stride(3),
|
||||
# Params
|
||||
chans,
|
||||
tsz,
|
||||
BLOCK_SIZE_C=block_size_c,
|
||||
)
|
||||
|
||||
state_out = torch.cat((state, state_out), dim=2)
|
||||
|
||||
return wkvs, state_out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_rwkv4_backward_kernel(
|
||||
# W
|
||||
w_ptr,
|
||||
w_s_c,
|
||||
# U
|
||||
u_ptr,
|
||||
u_s_c,
|
||||
# K
|
||||
k_ptr,
|
||||
k_s_b,
|
||||
k_s_t,
|
||||
k_s_c,
|
||||
# V
|
||||
v_ptr,
|
||||
v_s_b,
|
||||
v_s_t,
|
||||
v_s_c,
|
||||
# State
|
||||
state_ptr,
|
||||
state_s_b,
|
||||
state_s_abe,
|
||||
state_s_t,
|
||||
state_s_c,
|
||||
# WKV grad
|
||||
gwkv_ptr,
|
||||
gwkv_s_b,
|
||||
gwkv_s_t,
|
||||
gwkv_s_c,
|
||||
# Output state grad
|
||||
gstate_out_ptr,
|
||||
gstate_out_s_b,
|
||||
gstate_out_s_abe,
|
||||
gstate_out_s_c,
|
||||
# W grad
|
||||
gw_ptr,
|
||||
gw_s_c,
|
||||
# U grad
|
||||
gu_ptr,
|
||||
gu_s_c,
|
||||
# K grad
|
||||
gk_ptr,
|
||||
gk_s_b,
|
||||
gk_s_t,
|
||||
gk_s_c,
|
||||
# V grad
|
||||
gv_ptr,
|
||||
gv_s_b,
|
||||
gv_s_t,
|
||||
gv_s_c,
|
||||
# State grad
|
||||
gstate_ptr,
|
||||
gstate_s_b,
|
||||
gstate_s_abe,
|
||||
gstate_s_c,
|
||||
# Params
|
||||
tsz,
|
||||
chans,
|
||||
BLOCK_SIZE_C: tl.constexpr,
|
||||
):
|
||||
# Parallelize over the batch dimension.
|
||||
b_idx = tl.program_id(0)
|
||||
c_idx = tl.program_id(1)
|
||||
|
||||
cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
|
||||
cmask = cs < chans
|
||||
|
||||
# Pointers to the batch (and possibly channel) for the input tensors.
|
||||
k_ptr = k_ptr + b_idx * k_s_b
|
||||
v_ptr = v_ptr + b_idx * v_s_b
|
||||
alpha_ptr = state_ptr + b_idx * state_s_b
|
||||
beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
|
||||
eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe
|
||||
|
||||
# Pointers to the batch (and possibly channel) for the output tensors.
|
||||
gk_ptr = gk_ptr + b_idx * gk_s_b
|
||||
gv_ptr = gv_ptr + b_idx * gv_s_b
|
||||
|
||||
# Pointers to gradients which were received by the function.
|
||||
gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b
|
||||
galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b
|
||||
gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe
|
||||
geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe
|
||||
|
||||
# Loads parameters.
|
||||
galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
|
||||
gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
|
||||
geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
|
||||
w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)
|
||||
u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)
|
||||
|
||||
# Gradient accumulators.
|
||||
gw = tl.zeros_like(w)
|
||||
gu = tl.zeros_like(u)
|
||||
|
||||
alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
|
||||
beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
|
||||
eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
|
||||
|
||||
for t in range(tsz):
|
||||
tc = tsz - t - 1
|
||||
|
||||
kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)
|
||||
vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)
|
||||
|
||||
alpha_curr = alpha_prev
|
||||
beta_curr = beta_prev
|
||||
eps_curr = eps_prev
|
||||
|
||||
alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
|
||||
beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
|
||||
eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
|
||||
|
||||
ukt = u + kt
|
||||
tau = tl.maximum(ukt, eps_prev)
|
||||
e1 = tl.exp(eps_prev - tau)
|
||||
e2 = tl.exp(ukt - tau)
|
||||
|
||||
euke = tl.exp(ukt + eps_prev - 2 * tau)
|
||||
|
||||
denom = e1 * beta_prev + e2
|
||||
denom_sq = denom * denom
|
||||
|
||||
gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)
|
||||
|
||||
# Backpropagates wkv gradients.
|
||||
guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq
|
||||
gu += guk
|
||||
gk = guk
|
||||
gv = gwkvt * e2 / denom
|
||||
|
||||
galpha_wkv = gwkvt * e1 / denom
|
||||
gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq
|
||||
geps_wkv_denom = e1 * beta_prev + e2
|
||||
geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)
|
||||
|
||||
e1 = tl.exp(w + eps_prev - eps_curr)
|
||||
e2 = tl.exp(kt - eps_curr)
|
||||
|
||||
# Backpropagates alpha gradients.
|
||||
galpha_we = galpha * e1 * alpha_prev
|
||||
gw += galpha_we
|
||||
gk += galpha * e2 * vt
|
||||
gv += galpha * e2
|
||||
geps += galpha * -alpha_curr
|
||||
|
||||
# Backpropagates beta gradients.
|
||||
gbeta_we = gbeta * e1 * beta_prev
|
||||
gw += gbeta_we
|
||||
gk += gbeta * e2
|
||||
geps += gbeta * -beta_curr
|
||||
|
||||
# Backpropagates epsilon gradients.
|
||||
geps_mask = w + eps_prev > kt
|
||||
geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))
|
||||
gw += geps_we
|
||||
gk += tl.where(geps_mask, tl.zeros_like(geps), geps)
|
||||
|
||||
# Stores the gradients for k and v.
|
||||
tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)
|
||||
tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)
|
||||
|
||||
# Computes new gradients for alpha and beta.
|
||||
galpha = galpha * e1 + galpha_wkv
|
||||
gbeta = gbeta * e1 + gbeta_wkv
|
||||
geps = galpha_we + gbeta_we + geps_we + geps_wkv
|
||||
|
||||
# Stores final gradients for alpha and beta.
|
||||
galpha_ptr = gstate_ptr + b_idx * gstate_s_b
|
||||
gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe
|
||||
geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe
|
||||
tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)
|
||||
tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)
|
||||
tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)
|
||||
|
||||
# Stores final gradients for w and u.
|
||||
gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32)
|
||||
gw_temp += gw
|
||||
tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask)
|
||||
gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32)
|
||||
gu_temp += gu
|
||||
tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask)
|
||||
|
||||
|
||||
def fused_recurrent_rwkv4_backward(
|
||||
w: Tensor,
|
||||
u: Tensor,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
state: Tensor,
|
||||
grad_wkv: Tensor,
|
||||
grad_state: Tensor,
|
||||
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
bsz, tsz, chans = k.shape
|
||||
|
||||
gw = torch.zeros_like(w) # New tensors to output.
|
||||
gu = torch.zeros_like(u)
|
||||
gk = torch.empty_like(k)
|
||||
gv = torch.empty_like(v)
|
||||
gstate = k.new_empty(bsz, 3, 1, chans)
|
||||
|
||||
block_size_c = get_block_size_c(chans) # Constants.
|
||||
|
||||
def grid(meta: dict[str, Any]) -> tuple[int, ...]:
|
||||
return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))
|
||||
|
||||
fused_recurrent_rwkv4_backward_kernel[grid](
|
||||
# W
|
||||
w,
|
||||
w.stride(0),
|
||||
# U
|
||||
u,
|
||||
u.stride(0),
|
||||
# K
|
||||
k,
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
# V
|
||||
v,
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
# State
|
||||
state,
|
||||
state.stride(0),
|
||||
state.stride(1),
|
||||
state.stride(2),
|
||||
state.stride(3),
|
||||
# WKV grad
|
||||
grad_wkv,
|
||||
grad_wkv.stride(0),
|
||||
grad_wkv.stride(1),
|
||||
grad_wkv.stride(2),
|
||||
# Output state grad
|
||||
grad_state,
|
||||
grad_state.stride(0),
|
||||
grad_state.stride(1),
|
||||
grad_state.stride(3),
|
||||
# W grad
|
||||
gw,
|
||||
gw.stride(0),
|
||||
# U grad
|
||||
gu,
|
||||
gu.stride(0),
|
||||
# K grad
|
||||
gk,
|
||||
gk.stride(0),
|
||||
gk.stride(1),
|
||||
gk.stride(2),
|
||||
# V grad
|
||||
gv,
|
||||
gv.stride(0),
|
||||
gv.stride(1),
|
||||
gv.stride(2),
|
||||
# State grad
|
||||
gstate,
|
||||
gstate.stride(0),
|
||||
gstate.stride(1),
|
||||
gstate.stride(3),
|
||||
# Params
|
||||
tsz,
|
||||
chans,
|
||||
BLOCK_SIZE_C=block_size_c,
|
||||
)
|
||||
|
||||
return gw, gu, gk, gv, gstate
|
||||
|
||||
|
||||
class FusedRecurrentRWKV4Function(Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: FunctionCtx,
|
||||
w: Tensor,
|
||||
u: Tensor,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
state: Tensor,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
ctx.input_dtype = k.dtype
|
||||
|
||||
if (
|
||||
w.device.type != "cuda"
|
||||
or u.device.type != "cuda"
|
||||
or k.device.type != "cuda"
|
||||
or v.device.type != "cuda"
|
||||
):
|
||||
raise ValueError(
|
||||
"Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices."
|
||||
)
|
||||
|
||||
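# RWKV keeps the per-channel time decay strictly negative by parameterizing it as -exp(w).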
w = -torch.exp(w.float().contiguous())
|
||||
if k.dtype == torch.float16:
|
||||
u = u.float()
|
||||
k = k.float()
|
||||
v = v.float()
|
||||
u = u.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
wkv, state_out = fused_recurrent_rwkv4_forward(w, u, k, v, state)
|
||||
ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1])
|
||||
return wkv, state_out[:, :, -1:]
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx: FunctionCtx, gwkv: Tensor, gstate: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors)
|
||||
gw, gu, gk, gv, gstate = fused_recurrent_rwkv4_backward(w, u, k, v, state, gwkv, gstate)
|
||||
return gw, gu, gk, gv, gstate
|
||||
|
||||
|
||||
def fused_recurrent_rwkv4(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
|
||||
return FusedRecurrentRWKV4Function.apply(w, u, k, v, state)
|
||||
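A usage sketch for fused_recurrent_rwkv4 defined above. The (bsz, 3, 1, chans) state packs (alpha, beta, eps) along dim 1, matching the strides used by the kernels; starting eps at a very negative value is the common RWKV initialization and is an assumption of this sketch.
import torch

bsz, tsz, chans = 2, 16, 512
w = torch.randn(chans, device='cuda')  # per-channel time decay (the forward applies -exp to it)
u = torch.randn(chans, device='cuda')  # per-channel "time first" bonus
k = torch.randn(bsz, tsz, chans, device='cuda')
v = torch.randn(bsz, tsz, chans, device='cuda')
state = torch.zeros(bsz, 3, 1, chans, device='cuda')  # (alpha, beta, eps) stacked on dim 1
state[:, 2] = -1e38                                   # eps starts very negative (assumed init)
wkv, last_state = fused_recurrent_rwkv4(w, u, k, v, state)
# wkv: (bsz, tsz, chans); last_state: (bsz, 3, 1, chans), usable as the next chunk's state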
9
finetune/lora/v6/fla/ops/rwkv6/__init__.py
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .chunk import chunk_rwkv6
|
||||
from .recurrent_fuse import fused_recurrent_rwkv6
|
||||
|
||||
__all__ = [
|
||||
'chunk_rwkv6',
|
||||
'fused_recurrent_rwkv6'
|
||||
]
|
||||
921
finetune/lora/v6/fla/ops/rwkv6/chunk.py
vendored
Normal file
@@ -0,0 +1,921 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.ops.utils import chunk_reversed_cumsum_fwd
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BS': 16}, num_warps=2),
|
||||
triton.Config({'BS': 16}, num_warps=4),
|
||||
triton.Config({'BS': 16}, num_warps=8),
|
||||
triton.Config({'BS': 32}, num_warps=2),
|
||||
triton.Config({'BS': 32}, num_warps=4),
|
||||
triton.Config({'BS': 32}, num_warps=8),
|
||||
triton.Config({'BS': 64}, num_warps=2),
|
||||
triton.Config({'BS': 64}, num_warps=4),
|
||||
triton.Config({'BS': 64}, num_warps=8),
|
||||
],
|
||||
key=['S']
|
||||
)
|
||||
@triton.jit
|
||||
def chunk_rwkv6_fwd_kernel_cum(
|
||||
s,
|
||||
o,
|
||||
o_minus_s,
|
||||
s_s_h,
|
||||
s_s_t,
|
||||
s_s_d,
|
||||
T: tl.constexpr,
|
||||
S: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BS: tl.constexpr
|
||||
):
|
||||
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
||||
|
||||
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
# [BT, BS]
|
||||
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = tl.dot(m_s, b_s, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def post_process_grad(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
u,
|
||||
do,
|
||||
dk,
|
||||
dq,
|
||||
du,
|
||||
scale,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
H,
|
||||
T: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_h = i_bh % H
|
||||
|
||||
# Note that BK = tl.next_power_of_2(K), BV = tl.next_power_of_2(V)
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
|
||||
p_du = tl.make_block_ptr(du + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))
|
||||
p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,))
|
||||
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
b_u = tl.load(p_u, boundary_check=(0,))
|
||||
|
||||
b_vdo = tl.sum(b_v * b_do, axis=1)
|
||||
b_du = b_vdo[:, None] * b_k * b_q * scale
|
||||
b_dq = b_vdo[:, None] * b_k * b_u[None, :] * scale
|
||||
b_dk = b_vdo[:, None] * b_q * b_u[None, :] * scale
|
||||
|
||||
b_dq += tl.load(p_dq, boundary_check=(0, 1))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
b_dk += tl.load(p_dk, boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
tl.store(p_du, b_du.to(p_du.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_fwd_kernel_h(
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
h,
|
||||
h0,
|
||||
ht,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
for i_t in range(NT):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
if i_t < NT - 1:
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
else:
|
||||
b_gn = tl.min(b_g, axis=1)
|
||||
b_h *= tl.exp(b_gn)[:, None]
|
||||
b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
|
||||
b_h += tl.dot(b_k, b_v, allow_tf32=False)
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_fwd_kernel_intra(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
gs,
|
||||
u,
|
||||
A,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
scale,
|
||||
H,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
NC: tl.constexpr,
|
||||
DK: tl.constexpr
|
||||
):
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
|
||||
i_h = i_bh % H
|
||||
n_bh = tl.num_programs(2)
|
||||
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_q = i_t * BT + i_i * BC
|
||||
m_k = o_k < K
|
||||
|
||||
if i_i > i_j:
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
|
||||
p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
|
||||
# [BK,]
|
||||
b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0)
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
b_qg = (b_q * tl.exp(b_gs - b_gn[None, :]) * scale).to(b_q.dtype)
|
||||
# [BK, BC]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)
|
||||
# [BC, BC]
|
||||
b_A = tl.dot(b_qg, b_kg, allow_tf32=False)
|
||||
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
|
||||
elif i_i == i_j:
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
p_q_self = tl.make_block_ptr(q + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
o_i = tl.arange(0, BC)
|
||||
o_g = i_bh * T * K + (i_t * BT + i_j * BC) * K + o_k
|
||||
o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
|
||||
m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
|
||||
p_u = tl.make_block_ptr(u + i_h * DK, (DK,), (1,), (i_k * BK), (BK,), (0,))
|
||||
b_u = tl.load(p_u, boundary_check=(0,))
|
||||
for j in range(0, BC):
|
||||
# [BK,]
|
||||
b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
|
||||
b_gk = tl.load(g + o_g + j * K, mask=(m_k & ((i_t * BT + i_j * BC + j) < T)), other=0).to(tl.float32)
|
||||
# [BC,]
|
||||
b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_gs - b_gk[None, :]) * scale, 1)
|
||||
b_A = tl.where(o_i > j, b_A, 0.)
|
||||
# self
|
||||
b_q_self = tl.load(p_q_self, boundary_check=(0,)).to(tl.float32)
|
||||
A_self = tl.sum(b_q_self * b_k * b_u * scale, axis=0)
|
||||
m_self = tl.arange(0, BC) == j
|
||||
b_A = tl.where(m_self, A_self[None], b_A)
|
||||
tl.store(A + o_A + j, b_A.to(A.dtype.element_ty), mask=m_A)
|
||||
p_k = tl.advance(p_k, (K,))
|
||||
p_q_self = tl.advance(p_q_self, (K,))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_fwd_kernel_inter(
|
||||
q,
|
||||
v,
|
||||
gs,
|
||||
h,
|
||||
o,
|
||||
A,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BT, BK]
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
# [BT, BK]
|
||||
b_qg = (b_q * tl.exp(b_gs)).to(b_q.dtype)
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# works but dkw, owing to divine benevolence
|
||||
# [BT, BV]
|
||||
if i_k >= 0:
|
||||
b_o += tl.dot(b_qg, b_h, allow_tf32=False)
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
b_o += tl.dot(b_A, b_v, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_bwd_kernel_dh(
|
||||
q,
|
||||
g,
|
||||
gs,
|
||||
do,
|
||||
dh,
|
||||
dh0,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for i_t in range(NT - 1, -1, -1):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
# [BK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BT, BV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
|
||||
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BK, BV]
|
||||
b_dh *= tl.exp(b_gn)[:, None]
|
||||
# [BK, BT]
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
b_q = (b_q * tl.exp(b_gs)).to(b_q.dtype)
|
||||
|
||||
# [BK, BV]
|
||||
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_dh0 = tl.make_block_ptr(dh0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_bwd_kernel_inter(
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
gs,
|
||||
A,
|
||||
do,
|
||||
dh,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
dA,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
n_bh = tl.num_programs(2)
|
||||
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gq = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
|
||||
|
||||
# [BT, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_gq = tl.load(p_gq, boundary_check=(0, 1))
|
||||
b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk)
|
||||
b_k = (b_k * b_gn).to(b_k.dtype)
|
||||
# [BT, BT]
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BV, BK]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh = tl.load(p_dh, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BV]
|
||||
b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
|
||||
if i_k == 0:
|
||||
b_dv += tl.dot(b_A, b_do, allow_tf32=False)
|
||||
b_do = (b_do * scale).to(b_do.dtype)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_do, b_h, allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
|
||||
|
||||
b_dq = b_dq * tl.exp(b_gq)
|
||||
b_dk = b_dk * b_gn
|
||||
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = o_i[:, None] > o_i[None, :]
|
||||
# [BT, BT]
|
||||
b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
|
||||
if i_k == 0:
|
||||
tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_bwd_kernel_intra(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
gs,
|
||||
dA,
|
||||
dq,
|
||||
dk,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
NC: tl.constexpr
|
||||
):
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_t, i_i = i_c // NC, i_c % NC
|
||||
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_q = i_t * BT + i_i * BC
|
||||
m_k = o_k < K
|
||||
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
# [BK,]
|
||||
b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0)
|
||||
# [BC, BK]
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
b_dq = tl.zeros([BC, BK], dtype=tl.float32)
|
||||
for i_j in range(0, i_i):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
|
||||
# [BC, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
|
||||
# [BC, BC]
|
||||
b_dA = tl.load(p_dA, boundary_check=(0, 1))
|
||||
# [BC, BK]
|
||||
b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
|
||||
b_dq *= tl.exp(b_gs - b_gn[None, :])
|
||||
|
||||
o_i = tl.arange(0, BC)
|
||||
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
|
||||
m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
|
||||
|
||||
for j in range(0, BC):
|
||||
p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
# [BC,]
|
||||
b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
|
||||
# [BK,]
|
||||
b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
|
||||
b_gkj = tl.load(g + i_bh * T * K + (o_q + j) * K + o_k, mask=(m_k & ((o_q + j) < T)), other=0)
|
||||
# [BC, BK]
|
||||
m_i = o_i[:, None] > j
|
||||
# [BC, BK]
|
||||
b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_gs - b_gkj[None, :]), 0.)
|
||||
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
|
||||
b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
tl.debug_barrier()
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BC, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_dk = tl.zeros([BC, BK], dtype=tl.float32)
|
||||
for i_j in range(i_i + 1, NC):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
b_qg = (b_q * tl.exp(b_gs - b_gn[None, :])).to(b_q.dtype)
|
||||
# [BC, BC]
|
||||
b_dA = tl.load(p_dA, boundary_check=(0, 1))
|
||||
# [BC, BK]
|
||||
b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False)
|
||||
b_dk *= tl.exp(b_gn[None, :] - b_gk)
|
||||
|
||||
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
|
||||
for j in range(0, BC):
|
||||
p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
|
||||
p_gqj = tl.make_block_ptr(gs + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BC,]
|
||||
b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
|
||||
# [BK,]
|
||||
b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
|
||||
b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32)
|
||||
# [BC, BK]
|
||||
m_i = o_i[:, None] < j
|
||||
b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.)
|
||||
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class ChunkRWKV6Function(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(ctx, r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level):
|
||||
q = r # alias
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT, BC = 64, 16
|
||||
BK = min(64, triton.next_power_of_2(K))
|
||||
BV = min(64, triton.next_power_of_2(V))
|
||||
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
|
||||
NK = triton.cdiv(K, BK)
|
||||
NV = triton.cdiv(V, BV)
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
num_stages = 1
|
||||
|
||||
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NV, NK, B * H)
|
||||
chunk_rwkv6_fwd_kernel_h[grid](
|
||||
k, v, g, h, h0, ht,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=h0 is not None,
|
||||
STORE_FINAL_STATE=ht is not None,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return h
|
||||
|
||||
final_state = None
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(B, H, K, V, dtype=torch.float)
|
||||
|
||||
g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)
|
||||
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
|
||||
# keep the cumulative log decays in fp32
|
||||
# this kernel is equivalent to
|
||||
# g_org = g_org.view(B, H, NT, BT, -1)
|
||||
# g = g_org.cumsum(-2).view(B, H, T, -1)
|
||||
# gs = g - g_org
|
||||
chunk_rwkv6_fwd_kernel_cum[grid](
|
||||
g_org, g, gs,
|
||||
g.stride(1), g.stride(2), g.stride(3),
|
||||
T=T, S=K, BT=BT
|
||||
)
|
||||
h = fwd_inner(
|
||||
q=q, k=k, v=v, g=g,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
h0=initial_state if initial_state is not None else None,
|
||||
ht=final_state if final_state is not None else None
|
||||
)
|
||||
A = q.new_zeros(NK, B, H, T, BT)
|
||||
grid = (NK, NT * NC * NC, B * H)
|
||||
chunk_rwkv6_fwd_kernel_intra[grid](
|
||||
q, k, g, gs, u, A,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
scale,
|
||||
H=H, T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, DK=K,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
A = A.sum(0, dtype=A.dtype)
|
||||
o = torch.empty_like(v)
|
||||
|
||||
grid = (NV, NT, B * H)
|
||||
chunk_rwkv6_fwd_kernel_inter[grid](
|
||||
q, v, gs, h, o, A,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
if checkpoint_level >= 1:
|
||||
del h
|
||||
h = None
|
||||
del g, gs
|
||||
ctx.save_for_backward(q, k, v, g_org, u, h, initial_state, A)
|
||||
ctx.BT = BT
|
||||
ctx.scale = scale
|
||||
ctx.checkpoint_level = checkpoint_level
|
||||
return o, final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, do, dht=None):
|
||||
q, k, v, g, u, h, initial_state, A = ctx.saved_tensors
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT, BC = ctx.BT, 16
|
||||
BK = min(64, triton.next_power_of_2(K))
|
||||
BV = min(64, triton.next_power_of_2(V))
|
||||
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
|
||||
NK = triton.cdiv(K, BK)
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
num_stages = 1
|
||||
|
||||
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NV, NK, B * H)
|
||||
chunk_rwkv6_fwd_kernel_h[grid](
|
||||
k, v, g, h, h0, ht,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=h0 is not None,
|
||||
STORE_FINAL_STATE=ht is not None,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return h
|
||||
|
||||
def bwd_inner(q, g, gs, h0, do, B, H, T, K, V, BT, BK, BV, NT, scale):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
dh = q.new_empty(B, H, NT * K, V)
|
||||
dh0 = torch.empty_like(h0) if h0 is not None else None
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_rwkv6_bwd_kernel_dh[grid](
|
||||
q, g, gs, do, dh, dh0,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
do.stride(1), do.stride(2), do.stride(3),
|
||||
dh.stride(1), dh.stride(2), dh.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=h0 is not None,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return dh, dh0
|
||||
|
||||
# recompute cumulative log decays.
|
||||
g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)
|
||||
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
|
||||
# keep the cumulative log decays in fp32
|
||||
# this kernel is equivalent to
|
||||
# g = g_org.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
# gs = g - g_org
|
||||
chunk_rwkv6_fwd_kernel_cum[grid](
|
||||
g_org, g, gs,
|
||||
g.stride(1), g.stride(2), g.stride(3),
|
||||
T=T, S=K, BT=BT
|
||||
)
|
||||
|
||||
# rerun the forward pass to get h if checkpoint_level >= 1
|
||||
if ctx.checkpoint_level == 1:
|
||||
h = fwd_inner(
|
||||
q=q, k=k, v=v, g=g,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
h0=initial_state if initial_state is not None else None,
|
||||
ht=None
|
||||
)
|
||||
|
||||
scale = ctx.scale
|
||||
dh, dh0 = bwd_inner(
|
||||
q, g, gs, initial_state, do,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
scale=scale
|
||||
)
|
||||
dq = torch.empty_like(q, dtype=torch.float)
|
||||
dk = torch.empty_like(k, dtype=torch.float)
|
||||
dv = v.new_empty(NK, *v.shape)
|
||||
dA = q.new_zeros(B, H, T, BT)
|
||||
grid = (NK, NT, B * H)
|
||||
chunk_rwkv6_bwd_kernel_inter[grid](
|
||||
k, v, h, g, gs, A, do, dh, dq, dk, dv, dA,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
dv = dv.sum(0, dtype=dv.dtype)
|
||||
grid = (NK, NT * NC, B * H)
|
||||
chunk_rwkv6_bwd_kernel_intra[grid](
|
||||
q, k, g, gs, dA, dq, dk,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
# TODO: fuse?
|
||||
dg = (dq * q)[:, :, 1:] - (dk * k)[:, :, 0:-1]
|
||||
dg = torch.nn.functional.pad(dg, (0, 0, 0, 1, 0, 0, 0, 0), value=0)
|
||||
dg = chunk_reversed_cumsum_fwd(dg).to(g)
|
||||
# equivalent to the following pytorch code.
|
||||
# du = ((do * v).sum(-1)[..., None] * k * q * scale).sum(-2).to(u)
|
||||
# dq += ((do * v).sum(-1)[..., None] * k * scale * u[:, :, None, :])
|
||||
# dk += ((do * v).sum(-1)[..., None] * q * scale * u[:, :, None, :])
|
||||
BT = 64
|
||||
grid = (triton.cdiv(T, BT), B * H)
|
||||
du = torch.empty_like(g, dtype=torch.float)
|
||||
post_process_grad[grid](
|
||||
q, k, v, u, do, dk, dq, du, scale,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3), H=H,
|
||||
T=T, BT=BT, K=K, V=V, BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V),
|
||||
num_warps=4
|
||||
)
|
||||
du = du.sum([0, 2])
|
||||
return dq.to(q), dk.to(k), dv.to(v), dg.to(g), du.to(u), None, dh0, None, None
|
||||
|
||||
|
||||
def chunk_rwkv6(
|
||||
r: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
scale: Optional[float] = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
checkpoint_level: Optional[int] = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
r (torch.Tensor):
|
||||
receptance of shape `(B, H, T, K)`. Alias: q, the query in linear attention.
|
||||
k (torch.Tensor):
|
||||
keys of shape `(B, H, T, K)`
|
||||
v (torch.Tensor):
|
||||
values of shape `(B, H, T, V)`
|
||||
g (torch.Tensor):
|
||||
data-dependent decays of shape `(B, H, T, K)` in log space! Alias: w.
|
||||
u (torch.Tensor):
|
||||
bonus of shape `(H, K)`
|
||||
scale (Optional[float]):
|
||||
Scale factor for the RWKV6 attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `(B, H, K, V)`. Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
|
||||
checkpoint_level (Optional[int]):
|
||||
Checkpointing level; higher values save more memory but require more recomputation during the backward pass.
|
||||
Default: `0`:
|
||||
- Level `0`: store forward hidden states for backprop.
|
||||
- Level `1`: recompute the forward hidden states during backward.
|
||||
"""
|
||||
assert checkpoint_level in [0, 1]
|
||||
if scale is None:
|
||||
scale = r.shape[-1] ** -0.5
|
||||
o, final_state = ChunkRWKV6Function.apply(r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level)
|
||||
return o, final_state
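

# A minimal, hedged sketch of the checkpoint_level trade-off documented above:
# level 1 drops the chunk-level hidden states after the forward pass and
# recomputes them during backward, trading extra compute for lower memory use.
# The helper name and all shapes/values below are illustrative only.
def _example_chunk_rwkv6_checkpointing():
    import torch.nn.functional as F
    B, H, T, K, V = 2, 4, 256, 64, 64
    r = torch.randn(B, H, T, K, device='cuda', requires_grad=True)
    k = torch.randn(B, H, T, K, device='cuda', requires_grad=True)
    v = torch.randn(B, H, T, V, device='cuda', requires_grad=True)
    # decays are expected in log space, hence logsigmoid
    g = F.logsigmoid(torch.randn(B, H, T, K, device='cuda')).requires_grad_(True)
    u = torch.randn(H, K, device='cuda', requires_grad=True)
    o, ht = chunk_rwkv6(r, k, v, g, u, output_final_state=True, checkpoint_level=1)
    o.sum().backward()
    return o, ht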
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6
|
||||
B = 4
|
||||
H = 4
|
||||
L = 1024
|
||||
K = 100
|
||||
V = 120
|
||||
|
||||
torch.manual_seed(0)
|
||||
dtype = torch.float32
|
||||
q = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True)
|
||||
k = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True)
|
||||
v = torch.randn(B, H, L, V).cuda().to(dtype).requires_grad_(True)
|
||||
w = (-torch.randn(B, H, L, K).exp()).cuda().to(torch.float32).requires_grad_(True)
|
||||
u = torch.randn(H, K).cuda().to(dtype).requires_grad_(True)
|
||||
h0 = torch.randn(B, H, K, V).cuda().to(dtype).requires_grad_(True)
|
||||
do = torch.rand_like(v).cuda()
|
||||
o, ht = fused_recurrent_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True)
|
||||
o.backward(do)
|
||||
dq, q.grad = q.grad.clone(), None
|
||||
dk, k.grad = k.grad.clone(), None
|
||||
dv, v.grad = v.grad.clone(), None
|
||||
dw, w.grad = w.grad.clone(), None
|
||||
du, u.grad = u.grad.clone(), None
|
||||
dh0, h0.grad = h0.grad.clone(), None
|
||||
o2, ht2 = chunk_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True)
|
||||
o2.backward(do)
|
||||
torch.testing.assert_close(o, o2, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(ht, ht2, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(q.grad, dq, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(k.grad, dk, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(v.grad, dv, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(w.grad, dw, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(u.grad, du, rtol=0, atol=2e-4)
|
||||
torch.testing.assert_close(h0.grad, dh0, rtol=0, atol=2e-4)
|
||||
|
||||
print("All tests passed!")
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
# argument names to use as an x-axis for the plot
|
||||
x_names=['T'],
|
||||
# different possible values for `x_name`
|
||||
x_vals=[128 * 2 ** i for i in range(0, 8)],
|
||||
# argument name whose value corresponds to a different line in the plot
|
||||
line_arg='provider',
|
||||
# possible values for `line_arg``
|
||||
line_vals=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'],
|
||||
# label name for the lines
|
||||
line_names=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')],
|
||||
ylabel="Execution Time (ms)", # label name for the y-axis
|
||||
# name for the plot. Used also as a file name for saving the plot.
|
||||
plot_name="Performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(T, provider):
|
||||
device = 'cuda'
|
||||
dtype = torch.bfloat16
|
||||
requires_grad = True
|
||||
B, H, K = 16, 4, 128
|
||||
|
||||
q = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
|
||||
k = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
|
||||
v = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
|
||||
w = F.logsigmoid(torch.randn(B, H, T, K)).to(dtype=dtype, device=device).requires_grad_(True)
|
||||
u = torch.randn(H, K, device=device, requires_grad=requires_grad, dtype=dtype)
|
||||
|
||||
do = torch.ones_like(q, dtype=dtype)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
results = 0, 0, 0
|
||||
if provider == 'recurrent':
|
||||
results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u), quantiles=quantiles)
|
||||
if provider == 'chunk':
|
||||
results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u), quantiles=quantiles)
|
||||
if provider == 'recurrent_bwd':
|
||||
results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u)
|
||||
[0].backward(do), quantiles=quantiles)
|
||||
if provider == 'chunk_bwd':
|
||||
results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u)[0].backward(do), quantiles=quantiles)
|
||||
return results
|
||||
benchmark.run(print_data=True)
|
||||
79
finetune/lora/v6/fla/ops/rwkv6/chunk_naive.py
vendored
Normal file
79
finetune/lora/v6/fla/ops/rwkv6/chunk_naive.py
vendored
Normal file
@@ -0,0 +1,79 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from fla.ops.rwkv6.chunk import chunk_rwkv6
|
||||
from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6
|
||||
|
||||
|
||||
def naive_chunk_rwkv6(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
u,
|
||||
chunk_size=32,
|
||||
initial_state=None,
|
||||
output_final_state=True,
|
||||
):
|
||||
assert q.shape[-2] % chunk_size == 0
|
||||
orig_dtype = q.dtype
|
||||
num_chunk = q.shape[-2] // chunk_size
|
||||
u = u.unsqueeze(0)
|
||||
|
||||
q, k, v, w = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), (q, k, v, w))
|
||||
|
||||
w_cumsum = w.cumsum(-2)
|
||||
|
||||
kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp()
|
||||
wkv = kw.transpose(-1, -2) @ v
|
||||
|
||||
wkv_new = torch.zeros_like(wkv)
|
||||
|
||||
for i in range(num_chunk - 1):
|
||||
wkv_new[:, :, i+1] = (wkv_new[:, :, i] * w_cumsum[:, :, i, -1, :, None].exp()) + wkv[:, :, i]
|
||||
|
||||
o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, (q * (w_cumsum - w).exp()))
|
||||
|
||||
o_intra = torch.zeros_like(o_inter)
|
||||
for i in range(chunk_size):
|
||||
attn = (q[:, :, :, i, None] * k * (w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum).exp()).sum(-1)
|
||||
mask = (torch.arange(0, chunk_size) < i).to(attn.device)
|
||||
attn.masked_fill_(~mask, 0)
|
||||
intra_inter_o = (attn.unsqueeze(-1) * v).sum(-2)
|
||||
intra_intra_o = (q[:, :, :, i] * u.unsqueeze(2) * k[:, :, :, i]).sum(-1).unsqueeze(-1) * v[:, :, :, i]
|
||||
o_intra[:, :, :, i] = intra_inter_o + intra_intra_o
|
||||
o = o_inter + o_intra
|
||||
return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
B = 4
|
||||
H = 4
|
||||
L = 1024
|
||||
D = 100
|
||||
dtype = torch.bfloat16
|
||||
require_grad = True
|
||||
q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad)
|
||||
k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad)
|
||||
v = torch.randn(B, H, L, 2*D).cuda().to(dtype).requires_grad_(require_grad)
|
||||
w = torch.nn.functional.logsigmoid(torch.randn(B, H, L, D)).cuda().to(dtype).requires_grad_(require_grad)
|
||||
u = (torch.randn(H, D).cuda().to(dtype)).requires_grad_(require_grad)
|
||||
do = torch.rand_like(v).cuda()
|
||||
o2, _ = chunk_rwkv6(q, k, v, w.clone(), u)
|
||||
o, _ = fused_recurrent_rwkv6(q, k, v, w, u, scale=1.0)
|
||||
o.backward(do)
|
||||
dq, q.grad = q.grad.clone(), None
|
||||
dk, k.grad = k.grad.clone(), None
|
||||
dv, v.grad = v.grad.clone(), None
|
||||
dw, w.grad = w.grad.clone(), None
|
||||
du, u.grad = u.grad.clone(), None
|
||||
print((o - o2).abs().max())
|
||||
o2.backward(do)
|
||||
print((o-o2).abs().max())
|
||||
print((q.grad - dq).abs().max())
|
||||
print((k.grad - dk).abs().max())
|
||||
print((v.grad - dv).abs().max())
|
||||
print((w.grad - dw).abs().max())
|
||||
print((u.grad - du).abs().max())
|
||||
378
finetune/lora/v6/fla/ops/rwkv6/recurrent_fuse.py
vendored
Normal file
378
finetune/lora/v6/fla/ops/rwkv6/recurrent_fuse.py
vendored
Normal file
@@ -0,0 +1,378 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2024, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.ops.utils import chunk_reversed_cumsum_fwd
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_rwkv6_fwd_kernel(
|
||||
q, # query [B, H, T, K]
|
||||
k, # key [B, H, T, K]
|
||||
v, # value [B, H, T, V]
|
||||
w, # log gate [B, H, T, K]
|
||||
u, # bonus [H, K]
|
||||
o, # output [B, H, T, V]
|
||||
# initial hidden state [B, H, K, V]
|
||||
h0,
|
||||
ht, # final hidden state [B, H, K, V]
|
||||
s_k_h, # stride size: T * K
|
||||
s_v_h, # stride size: T * V
|
||||
scale, # K ** -0.5
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
|
||||
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
|
||||
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK
|
||||
|
||||
mask_bk = (i_k * BK + tl.arange(0, BK)) < K
|
||||
mask_bv = (i_v * BV + tl.arange(0, BV)) < V
|
||||
mask_kv = mask_bv[:, None] & mask_bk[None, :]
|
||||
|
||||
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
|
||||
for _ in range(0, T):
|
||||
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_w = tl.exp(b_w)
|
||||
b_kv = b_k[None, :] * b_v[:, None]
|
||||
b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :]
|
||||
b_o = tl.sum(b_o, axis=1)
|
||||
b_h = b_h * b_w[None, :]
|
||||
b_h += b_kv
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
|
||||
p_q += -K if REVERSE else K
|
||||
p_k += -K if REVERSE else K
|
||||
p_o += -V if REVERSE else V
|
||||
p_v += -V if REVERSE else V
|
||||
p_w += -K if REVERSE else K
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)
|
||||
|
||||
|
||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
|
||||
@triton.jit
|
||||
def fused_recurrent_rwkv6_bwd_kernel_dq(
|
||||
# B: batch size, H: num heads, T: sequence length, D: head dim
|
||||
# NV: number of splits along the V dimension. NK: number of splits along the K dimension
|
||||
k, # key [B, H, T, K]
|
||||
v, # value [B, H, T, V]
|
||||
w, # log gate [B, H, T, K]
|
||||
u, # bonus [H, K]
|
||||
|
||||
do, # gradient of output [B, H, T, V]
|
||||
dq, # gradient of query [NV, B, H, T, K]
|
||||
dq_aux, # gradient of query_aux [NV, B, H, T, K]
|
||||
|
||||
# initial hidden state [B, H, K, V]
|
||||
h0,
|
||||
|
||||
s_k_h, # stride size: T * K
|
||||
s_v_h, # stride size: T * V
|
||||
|
||||
scale, # K ** -0.5
|
||||
B: tl.constexpr, # B
|
||||
H: tl.constexpr, # H
|
||||
T: tl.constexpr, # T
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
K: tl.constexpr, # K
|
||||
V: tl.constexpr, # V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_dq_aux = dq_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK
|
||||
|
||||
mask_bk = i_k * BK + tl.arange(0, BK) < K
|
||||
mask_bv = i_v * BV + tl.arange(0, BV) < V
|
||||
mask_kv = mask_bv[:, None] & mask_bk[None, :]
|
||||
b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
for _ in range(0, T):
|
||||
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_kv = b_k[None, :] * b_v[:, None]
|
||||
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_w = tl.exp(b_w)
|
||||
h_q = b_h * b_do[:, None]
|
||||
b_dq = tl.sum(h_q + b_kv * b_u[None, :] * b_do[:, None], axis=0)
|
||||
b_dq *= scale
|
||||
b_dq_aux = tl.sum(h_q, axis=0)
|
||||
b_h = b_h * b_w[None, :]
|
||||
b_h += b_kv
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)
|
||||
tl.store(p_dq_aux, b_dq_aux.to(p_dq_aux.dtype.element_ty), mask=mask_bk)
|
||||
p_k += -K if REVERSE else K
|
||||
p_do += -V if REVERSE else V
|
||||
p_v += -V if REVERSE else V
|
||||
p_w += -K if REVERSE else K
|
||||
p_dq += -K if REVERSE else K
|
||||
p_dq_aux += -K if REVERSE else K
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_rwkv6_bwd_kernel_dkv(
|
||||
# B: batch size, H: num heads, T: sequence length, D: head dim
|
||||
# NV: number of splits along the V dimension. NK: number of splits along the K dimension
|
||||
q, # query [B, H, T, K]
|
||||
k, # key [B, H, T, K]
|
||||
v, # value [B, H, T, V]
|
||||
w, # log gate [B, H, T, K]
|
||||
u, # bonus [H, K]
|
||||
|
||||
do, # gradient of output [B, H, T, V]
|
||||
dk,
|
||||
dk_aux,
|
||||
dv,
|
||||
dh0,
|
||||
|
||||
# initial hidden state initialization [B, H, K, V]
|
||||
s_k_h, # stride size: T * K
|
||||
s_v_h, # stride size: T * V
|
||||
|
||||
scale, # K ** -0.5
|
||||
B, # B
|
||||
H, # H
|
||||
T, # T
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
K: tl.constexpr, # K
|
||||
V: tl.constexpr, # V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
|
||||
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
|
||||
p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
p_dk_aux = dk_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
|
||||
p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
mask_bk = i_k * BK + tl.arange(0, BK) < K
|
||||
mask_bv = i_v * BV + tl.arange(0, BV) < V
|
||||
mask_kv = mask_bk[:, None] & mask_bv[None, :]
|
||||
|
||||
p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK
|
||||
b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
|
||||
|
||||
for _ in range(T-1, -1, -1):
|
||||
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_dkv = b_q[:, None] * b_do[None, :]
|
||||
b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
|
||||
tl.store(p_dk_aux, b_dk.to(p_dk_aux.dtype.element_ty), mask=mask_bk)
|
||||
b_dk += tl.sum(b_dkv * b_u[:, None] * b_v[None, :], axis=1)
|
||||
b_dv = tl.sum((b_dh + (b_dkv * b_u[:, None])) * b_k[:, None], axis=0)
|
||||
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
|
||||
b_dh *= tl.exp(b_w)[:, None]
|
||||
b_dh += b_dkv
|
||||
|
||||
p_q += K if REVERSE else -K
|
||||
p_k += K if REVERSE else -K
|
||||
p_v += V if REVERSE else -V
|
||||
p_w += K if REVERSE else -K
|
||||
p_do += V if REVERSE else -V
|
||||
p_dk += K if REVERSE else -K
|
||||
p_dk_aux += K if REVERSE else -K
|
||||
p_dv += V if REVERSE else -V
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
|
||||
tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_kv)
|
||||
|
||||
|
||||
class FusedRecurrentRWKV6Function(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False):
|
||||
# alias
|
||||
q = r
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
|
||||
BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(B, H, K, V)
|
||||
else:
|
||||
final_state = None
|
||||
|
||||
o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
|
||||
grid = (NV, NK, B * H)
|
||||
fused_recurrent_rwkv6_fwd_kernel[grid](
|
||||
q, k, v, w, u, o, initial_state, final_state,
|
||||
k.stride(1),
|
||||
v.stride(1),
|
||||
scale,
|
||||
B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=final_state is not None,
|
||||
REVERSE=reverse,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
o = o.sum(0)
|
||||
ctx.save_for_backward(q, k, v, w, u, initial_state, o)
|
||||
ctx.scale = scale
|
||||
ctx.reverse = reverse
|
||||
# we do not need the gradient of the final state from the next chunk
|
||||
# similar to truncated BPTT
|
||||
if final_state is not None:
|
||||
final_state = final_state.detach()
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_bwd
|
||||
def backward(ctx, do, d_final_state=None):
|
||||
q, k, v, w, u, initial_state, o = ctx.saved_tensors
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
scale = ctx.scale
|
||||
|
||||
BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
dq = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
|
||||
dq_aux = torch.empty_like(dq)
|
||||
grid = (NV, NK, B * H)
|
||||
|
||||
fused_recurrent_rwkv6_bwd_kernel_dq[grid](
|
||||
k, v, w, u, do, dq, dq_aux, initial_state,
|
||||
q.stride(1),
|
||||
v.stride(1),
|
||||
scale,
|
||||
B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
REVERSE=ctx.reverse,
|
||||
)
|
||||
dq = dq.sum(0).to(q)
|
||||
dq_aux = dq_aux.sum(0)
|
||||
|
||||
BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
|
||||
dk = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
|
||||
dk_aux = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
|
||||
dv = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
|
||||
dh0 = initial_state.new_empty(B, H, K, V) if initial_state is not None else None
|
||||
grid = (NV, NK, B * H)
|
||||
fused_recurrent_rwkv6_bwd_kernel_dkv[grid](
|
||||
q, k, v, w, u, do, dk, dk_aux, dv, dh0,
|
||||
q.stride(1),
|
||||
v.stride(1),
|
||||
scale,
|
||||
B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
REVERSE=ctx.reverse,
|
||||
)
|
||||
dk = dk.sum(0).to(k)
|
||||
dv = dv.sum(0).to(v)
|
||||
dk_aux = dk_aux.sum(0)
|
||||
|
||||
dw = (dq_aux * q * scale)[:, :, 1:] - (dk_aux * k)[:, :, 0:-1]
|
||||
dw = torch.nn.functional.pad(dw, (0, 0, 0, 1, 0, 0, 0, 0), value=0)
|
||||
dw = chunk_reversed_cumsum_fwd(dw).to(w)
|
||||
|
||||
du = ((do * v).sum(-1)[..., None] * k * q * scale).sum([0, -2]).to(u)
|
||||
return dq, dk, dv, dw, du, None, dh0, None, None
|
||||
|
||||
|
||||
def fused_recurrent_rwkv6(
|
||||
r: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
scale: float = -1,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
causal: bool = True
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
r (torch.Tensor):
|
||||
receptance of shape `(B, H, T, K)`. Alias: q, the query in linear attention.
|
||||
k (torch.Tensor):
|
||||
keys of shape `(B, H, T, K)`
|
||||
v (torch.Tensor):
|
||||
values of shape `(B, H, T, V)`
|
||||
w (torch.Tensor):
|
||||
data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.
|
||||
u (torch.Tensor):
|
||||
bonus of shape `(H, K)`
|
||||
scale (float):
|
||||
Scale factor for the RWKV6 attention scores.
|
||||
If not provided (i.e. `scale == -1`), it will default to `1 / sqrt(K)`. Default: `-1`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `(B, H, K, V)`. Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
|
||||
"""
|
||||
if scale == -1:
|
||||
scale = r.shape[-1] ** -0.5
|
||||
o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state)
|
||||
return o, final_state
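

# A hedged sketch of stateful, segment-by-segment processing: the (detached) final
# state returned by one call is fed back as `initial_state` for the next segment,
# which is what the truncated-BPTT comment in FusedRecurrentRWKV6Function refers to.
# The helper name and the segment length are illustrative only.
def _example_streaming_rwkv6(r, k, v, w, u, segment_len=512):
    T = r.shape[2]
    state, outputs = None, []
    for s in range(0, T, segment_len):
        e = min(s + segment_len, T)
        o, state = fused_recurrent_rwkv6(
            r[:, :, s:e], k[:, :, s:e], v[:, :, s:e], w[:, :, s:e], u,
            initial_state=state, output_final_state=True
        )
        outputs.append(o)
    return torch.cat(outputs, dim=2), state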
|
||||
102
finetune/lora/v6/fla/ops/rwkv6/recurrent_naive.py
vendored
Normal file
102
finetune/lora/v6/fla/ops/rwkv6/recurrent_naive.py
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def naive_recurrent_rwkv6(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
scale: Optional[float] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: Optional[bool] = False
|
||||
):
|
||||
orig_dtype = q.dtype
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
|
||||
h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
|
||||
o = torch.zeros_like(v)
|
||||
|
||||
if scale is None:
|
||||
scale = K ** -0.5
|
||||
|
||||
if initial_state is not None:
|
||||
h += initial_state
|
||||
|
||||
for i in range(T):
|
||||
q_i = q[:, :, i, :] * scale
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i, :]
|
||||
w_i = w[:, :, i].exp()
|
||||
kv_i = k_i[..., None] * v_i[..., None, :]
|
||||
o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None]
|
||||
o[:, :, i] = o_i.sum(-2)
|
||||
h = h * w_i[..., None] + kv_i
|
||||
ht = h if output_final_state else None
|
||||
return o.to(orig_dtype), ht
|
||||
|
||||
|
||||
def naive_recurrent_rwkv6_bwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
u,
|
||||
o,
|
||||
do,
|
||||
initial_state=None,
|
||||
output_final_state=False
|
||||
):
|
||||
q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
|
||||
dq = torch.zeros_like(q)
|
||||
dq_aux = torch.zeros_like(q)
|
||||
|
||||
if initial_state is not None:
|
||||
h += initial_state
|
||||
|
||||
for i in range(T):
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i]
|
||||
w_i = w[:, :, i].exp()
|
||||
kv_i = k_i[..., None] * v_i[..., None, :]
|
||||
h_i = (h + u[None, ..., None] * kv_i)
|
||||
dq_i = (do[:, :, i, None, :] * h_i).sum(-1)
|
||||
dq_aux_i = (do[:, :, i, None, :] * h).sum(-1)
|
||||
dq[:, :, i] = dq_i
|
||||
dq_aux[:, :, i] = dq_aux_i
|
||||
h = h * w_i[..., None] + kv_i
|
||||
|
||||
du = torch.zeros_like(u)
|
||||
dh = torch.zeros_like(h)
|
||||
dk = torch.zeros_like(k)
|
||||
dk_aux = torch.zeros_like(k)
|
||||
dv = torch.zeros_like(v)
|
||||
|
||||
for i in range(T - 1, -1, -1):
|
||||
d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i]
|
||||
du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
|
||||
du += du_i
|
||||
dk_i = (dh * v_i[..., None, :]).sum(-1)
|
||||
dk_aux[:, :, i] = dk_i
|
||||
dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
|
||||
dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
|
||||
dv_i += (dh * k_i[..., None]).sum(-2)
|
||||
|
||||
dk[:, :, i] = dk_i
|
||||
dv[:, :, i] = dv_i
|
||||
dh = dh * w[:, :, i, :, None].exp() + d_kv_i
|
||||
|
||||
# dw = q * dq_aux - k * dk_aux
|
||||
dw = torch.zeros_like(w)
|
||||
for i in range(T - 2, -1, -1):
|
||||
dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]
|
||||
|
||||
return dq, dk, dv, dw, du
|
||||
5
finetune/lora/v6/fla/ops/simple_gla/README.md
vendored
Normal file
5
finetune/lora/v6/fla/ops/simple_gla/README.md
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
- Simple GLA
|
||||
|
||||
Gating mechanism from https://arxiv.org/abs/2103.02143. Compared to GLA, the gate is head-wise rather than element-wise, so the RetNet kernel can be adapted for training with matmuls without numerical instability. It is faster than GLA but less expressive; I will use it as a baseline for GLA.
|
||||
|
||||
$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.
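
A minimal reference sketch of this recurrence in plain PyTorch (naive per-step loop, head-wise scalar gate given in log space; tensor names are illustrative only):

```python
import torch

def simple_gla_reference(q, k, v, g):
    # q, k: [B, H, T, K]; v: [B, H, T, V]; g: [B, H, T] head-wise log-space gates
    B, H, T, K = q.shape
    V = v.shape[-1]
    S = torch.zeros(B, H, K, V, dtype=torch.float, device=q.device)
    o = torch.zeros(B, H, T, V, dtype=torch.float, device=q.device)
    for t in range(T):
        # S_t = g_t * S_{t-1} + k_t v_t^T, with g_t a scalar per head
        S = g[:, :, t, None, None].exp() * S \
            + k[:, :, t, :, None].float() * v[:, :, t, None, :].float()
        o[:, :, t] = torch.einsum('bhk,bhkv->bhv', q[:, :, t].float(), S)
    return o
```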
|
||||
8
finetune/lora/v6/fla/ops/simple_gla/__init__.py
vendored
Normal file
8
finetune/lora/v6/fla/ops/simple_gla/__init__.py
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .chunk import chunk_simple_gla
|
||||
|
||||
__all__ = [
|
||||
'chunk_simple_gla'
|
||||
]
|
||||
|
||||
415
finetune/lora/v6/fla/ops/simple_gla/chunk.py
vendored
Normal file
415
finetune/lora/v6/fla/ops/simple_gla/chunk.py
vendored
Normal file
@@ -0,0 +1,415 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def normalize_output(q, k, o):
|
||||
k = k.transpose(-2, -1)
|
||||
k = k.cumsum(-1)
|
||||
k = k.transpose(-2, -1)
|
||||
z = (q * k).sum(-1, keepdim=True)
|
||||
return o / (z + 1e-5)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_simple_gla_fwd_kernel_h(
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
# [BK, BV]
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V,
|
||||
(K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
for i_t in range(NT):
|
||||
p_k = tl.make_block_ptr(
|
||||
k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(
|
||||
v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,
|
||||
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)
|
||||
b_h *= tl.math.exp2(b_g_last)
|
||||
b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT))
|
||||
b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(
|
||||
final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_simple_gla_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_s = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(
|
||||
k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,
|
||||
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT]
|
||||
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
b_o += tl.dot(b_q, b_h, allow_tf32=False)
|
||||
b_s += tl.dot(b_q, b_k, allow_tf32=False)
|
||||
|
||||
p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
|
||||
b_g = tl.load(p_g)
|
||||
b_o = b_o * tl.math.exp2(b_g)[:, None]
|
||||
b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :])
|
||||
b_s = tl.where(m_s, b_s, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V),
|
||||
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V),
|
||||
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_simple_gla_bwd_kernel_dh(
|
||||
q,
|
||||
g,
|
||||
do,
|
||||
dh,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
# [BK, BV]
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for i_t in range(NT - 1, -1, -1):
|
||||
p_q = tl.make_block_ptr(
|
||||
q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V,
|
||||
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T +
|
||||
i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)
|
||||
# [BT, V]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1))
|
||||
b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_simple_gla_bwd_kernel_dqkv(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
do,
|
||||
dh,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
s_vo_h,
|
||||
s_vo_t,
|
||||
s_vo_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
scale,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
n_bh = tl.num_programs(2)
|
||||
o_i = tl.arange(0, BT)
|
||||
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T),
|
||||
(s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K),
|
||||
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_s = tl.dot(b_k, b_q, allow_tf32=False)
|
||||
p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
|
||||
b_g = tl.load(p_g)
|
||||
b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)
|
||||
mask = tl.math.exp2(b_g[None, :] - b_g[:, None])
|
||||
mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)
|
||||
b_s = b_s * mask
|
||||
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(
|
||||
v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t),
|
||||
(i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
|
||||
p_do = tl.make_block_ptr(
|
||||
do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V),
|
||||
(s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V),
|
||||
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BV, BK]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh = tl.load(p_dh, boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
|
||||
# [BT, BV]
|
||||
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.math.exp2(-b_g + b_g_last)[:, None] + \
|
||||
tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
b_dq = b_dq * tl.math.exp2(b_g)[:, None]
|
||||
b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None]
|
||||
b_ds = b_ds * tl.trans(mask)
|
||||
b_ds = b_ds.to(b_k.dtype)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
|
||||
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K),
|
||||
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K),
|
||||
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class SimpleGLAFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
@contiguous
|
||||
def forward(ctx, q, k, v, g, initial_state, output_final_state):
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT = 64
|
||||
BK, BV = min(64, triton.next_power_of_2(K)), min(
|
||||
64, triton.next_power_of_2(V))
|
||||
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
scale = K ** -0.5
|
||||
|
||||
BT = 64
|
||||
assert T % BT == 0, 'sequence length must be divisible by BT'
|
||||
g = g.reshape(B, H, -1, BT)
|
||||
g = g.cumsum(-1) * 1.44269504
|
||||
g = g.reshape(B, H, -1)
|
||||
|
||||
final_state = None
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
|
||||
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_simple_gla_fwd_kernel_h[grid](
|
||||
k, v, h, g, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2),
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=output_final_state,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
grid = (NV, NT, B * H)
|
||||
o = torch.empty_like(v)
|
||||
chunk_simple_gla_fwd_kernel_o[grid](
|
||||
q, k, v, h, g, o,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2),
|
||||
scale,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, h, g)
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
@contiguous
|
||||
def backward(ctx, do, d_ht=None):
|
||||
q, k, v, h, g = ctx.saved_tensors
|
||||
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT = 64
|
||||
BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min(
|
||||
32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
|
||||
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
scale = K ** -0.5
|
||||
|
||||
dh = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_simple_gla_bwd_kernel_dh[grid](
|
||||
q, g, do, dh,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
dh.stride(1), dh.stride(2),
|
||||
scale,
|
||||
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
grid = (NK, NT, B * H)
|
||||
dq = torch.empty_like(q)
|
||||
dk = torch.empty_like(k)
|
||||
dv = v.new_empty(NK, *v.shape)
|
||||
num_stages = 1
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
chunk_simple_gla_bwd_kernel_dqkv[grid](
|
||||
q, k, v, h, g, do, dh, dq, dk, dv,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
dh.stride(1), dh.stride(2),
|
||||
scale,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
dv = dv.sum(0)
|
||||
dg = (dq * q - dk * k).sum(-1)
|
||||
|
||||
def rev_cumsum(x):
|
||||
cumsum_x = x.cumsum(-1)
|
||||
rev_cumsum_x = cumsum_x[..., -1, None] - cumsum_x
|
||||
return rev_cumsum_x + x
|
||||
dg = rev_cumsum(dg)
|
||||
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, None
|
||||
|
||||
|
||||
def chunk_simple_gla(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor, # log decay
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
g = g.float()
|
||||
o, final_state = SimpleGLAFunction.apply(q, k, v, g, initial_state, output_final_state)
|
||||
return o, final_state
|
||||
52
finetune/lora/v6/fla/ops/simple_gla/naive.py
vendored
Normal file
52
finetune/lora/v6/fla/ops/simple_gla/naive.py
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def torch_simple_gla(q, k, v, g, chunk_size=64):
|
||||
q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5)
|
||||
k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size)
|
||||
v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size)
|
||||
g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size)
|
||||
g = g.cumsum(-1)
|
||||
kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None])
|
||||
S = torch.zeros_like(kv)
|
||||
|
||||
for i in range(1, g.shape[-2]):
|
||||
S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1]
|
||||
|
||||
inter = (q * g[..., None].exp()) @ S
|
||||
attn = q @ k.transpose(-1, -2)
|
||||
attn = attn * (g[..., None] - g[..., None, :]).exp()
|
||||
attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
|
||||
intra = attn @ v
|
||||
o = inter + intra
|
||||
return rearrange(o, 'b h n c d -> b h (n c) d')
|
||||
|
||||
|
||||
def torch_simple_gla_recurrent(q, k, v, g, chunk_size=64):
|
||||
# q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5)
|
||||
# k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size)
|
||||
# v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size)
|
||||
# g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size)
|
||||
# g = g.cumsum(-1)
|
||||
# kv = k.transpose(-1, -2) @ v
|
||||
|
||||
B, H, T, DK = q.shape
|
||||
q = q * (DK ** -0.5)
|
||||
_, _, _, DV = v.shape
|
||||
S = torch.zeros(B, H, DK, DV).to(q)
|
||||
o = torch.zeros(B, H, T, DV).to(q)
|
||||
for i in range(T):
|
||||
gate = g[:, :, i].exp()
|
||||
key = k[:, :, i]
|
||||
value = v[:, :, i]
|
||||
kv = key.unsqueeze(-1) * value.unsqueeze(-2)
|
||||
S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv
|
||||
q_i = q[:, :, i, :]
|
||||
o_i = (q_i.unsqueeze(-1) * S).sum(-2)
|
||||
o[:, :, i] = o_i
|
||||
|
||||
return o
|
||||
|
||||
579
finetune/lora/v6/fla/ops/utils.py
vendored
Normal file
579
finetune/lora/v6/fla/ops/utils.py
vendored
Normal file
@@ -0,0 +1,579 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BT': 16}, num_warps=2),
|
||||
triton.Config({'BT': 16}, num_warps=4),
|
||||
triton.Config({'BT': 16}, num_warps=8),
|
||||
triton.Config({'BT': 32}, num_warps=2),
|
||||
triton.Config({'BT': 32}, num_warps=4),
|
||||
triton.Config({'BT': 32}, num_warps=8),
|
||||
triton.Config({'BT': 64}, num_warps=2),
|
||||
triton.Config({'BT': 64}, num_warps=4),
|
||||
triton.Config({'BT': 64}, num_warps=8),
|
||||
],
|
||||
key=['S']
|
||||
)
|
||||
@triton.jit
|
||||
def logcumsumexp_fwd_kernel(
|
||||
s,
|
||||
z,
|
||||
s_s_h,
|
||||
s_s_t,
|
||||
s_s_d,
|
||||
T: tl.constexpr,
|
||||
S: tl.constexpr,
|
||||
BT: tl.constexpr
|
||||
):
|
||||
i_bh = tl.program_id(0)
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
||||
|
||||
b_mp = tl.full([S,], float('-inf'), dtype=tl.float32)
|
||||
b_zp = tl.zeros([S,], dtype=tl.float32)
|
||||
for i_t in range(tl.cdiv(T, BT)):
|
||||
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
|
||||
p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
|
||||
|
||||
# [BT, S]
|
||||
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
||||
# [S,]
|
||||
b_mc = tl.max(b_s, 0)
|
||||
# workaround for compiler bugs
|
||||
if i_t > 0:
|
||||
b_mc = tl.maximum(b_mp, b_mc)
|
||||
b_zp = b_zp * tl.exp(b_mp - b_mc)
|
||||
# [BT, S]
|
||||
b_s = tl.exp(b_s - b_mc)
|
||||
b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp
|
||||
# [S,]
|
||||
b_zc = tl.max(b_z, 0)
|
||||
b_mp = b_mc
|
||||
b_zp = b_zc
|
||||
# [BT, BS]
|
||||
# small eps to prevent underflows
|
||||
b_z = tl.log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc
|
||||
tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['S']
|
||||
)
|
||||
@triton.jit
|
||||
def softmax_fwd_kernel(
|
||||
s,
|
||||
p,
|
||||
s_s_h,
|
||||
s_s_t,
|
||||
s_s_d,
|
||||
T: tl.constexpr,
|
||||
S: tl.constexpr,
|
||||
BT: tl.constexpr
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
|
||||
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
|
||||
p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
|
||||
|
||||
# [BT, S]
|
||||
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
||||
# [BT]
|
||||
b_m = tl.max(b_s, 1)
|
||||
|
||||
# [BT, BS]
|
||||
b_s = tl.exp(b_s - b_m[:, None])
|
||||
b_z = tl.sum(b_s, 1)
|
||||
b_p = tl.where(b_s != 0, b_s / b_z[:, None], 0.)
|
||||
tl.store(p_p, b_p.to(p_p.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['S']
)
@triton.jit
def softmax_bwd_kernel(
    p,
    dp,
    ds,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)

    p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
    p_dp = tl.make_block_ptr(dp + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
    p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0))
    # [BT, S]
    b_p = tl.load(p_p, boundary_check=(0, 1)).to(tl.float32)
    b_dp = tl.load(p_dp, boundary_check=(0, 1)).to(tl.float32)
    # [BT,]
    b_pp = tl.sum(b_p * b_dp, 1)
    # [BT, S]
    b_ds = b_p * b_dp - b_p * b_pp[:, None]
    tl.store(p_ds, b_ds.to(p_ds.dtype.element_ty), boundary_check=(0, 1))


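The two softmax kernels above are a plain row-wise softmax over the slot dimension and its Jacobian-vector product; each program handles one BT-row tile. A hedged PyTorch equivalent for checking (editorial sketch only, not part of the vendored file):

import torch

def softmax_fwd_ref(s: torch.Tensor) -> torch.Tensor:
    # row-wise softmax over the last (slot) dimension, as in softmax_fwd_kernel
    return torch.softmax(s.float(), dim=-1)

def softmax_bwd_ref(p: torch.Tensor, dp: torch.Tensor) -> torch.Tensor:
    # ds = p * dp - p * sum(p * dp), as in softmax_bwd_kernel
    return p * dp - p * (p * dp).sum(-1, keepdim=True)
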
@triton.autotune(
    configs=[
        triton.Config({'BS': 32}, num_warps=2),
        triton.Config({'BS': 32}, num_warps=4),
        triton.Config({'BS': 32}, num_warps=8),
        triton.Config({'BS': 64}, num_warps=2),
        triton.Config({'BS': 64}, num_warps=4),
        triton.Config({'BS': 64}, num_warps=8),
        triton.Config({'BS': 128}, num_warps=2),
        triton.Config({'BS': 128}, num_warps=4),
        triton.Config({'BS': 128}, num_warps=8),
    ],
    key=['S']
)
@triton.jit
def recurrent_cumsum_fwd_kernel(
    s,
    z,
    s_s_h,
    s_s_t,
    T: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)

    o_s = i_s * BS + tl.arange(0, BS)
    mask = o_s < S

    b_z = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(0, T):
        # [BS]
        b_s = tl.load(s + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32)
        b_z = b_z + b_s

        tl.store(z + i_bh * s_s_h + i_t * s_s_t + o_s, b_z.to(s.dtype.element_ty), mask=mask)


@triton.autotune(
    configs=[
        triton.Config({'BS': 32}, num_warps=2),
        triton.Config({'BS': 32}, num_warps=4),
        triton.Config({'BS': 32}, num_warps=8),
        triton.Config({'BS': 64}, num_warps=2),
        triton.Config({'BS': 64}, num_warps=4),
        triton.Config({'BS': 64}, num_warps=8),
        triton.Config({'BS': 128}, num_warps=2),
        triton.Config({'BS': 128}, num_warps=4),
        triton.Config({'BS': 128}, num_warps=8),
    ],
    key=['S']
)
@triton.jit
def recurrent_cumsum_bwd_kernel(
    ds,
    dz,
    s_s_h,
    s_s_t,
    T: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)

    o_s = i_s * BS + tl.arange(0, BS)
    mask = o_s < S

    b_ds = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(T - 1, -1, -1):
        # [BS]
        b_dz = tl.load(dz + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32)
        b_ds = b_ds + b_dz

        tl.store(ds + i_bh * s_s_h + i_t * s_s_t + o_s, b_ds.to(ds.dtype.element_ty), mask=mask)


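The recurrent_cumsum_* kernels walk the time axis one step at a time over BS-wide slot blocks: the forward pass is an inclusive prefix sum, and the backward pass accumulates the incoming gradient from the end of the sequence. A minimal reference, assuming [B, H, T, S] inputs (editorial sketch, not part of the vendored file):

import torch

def recurrent_cumsum_fwd_ref(s: torch.Tensor) -> torch.Tensor:
    # inclusive prefix sum over the time axis
    return s.float().cumsum(2).to(s.dtype)

def recurrent_cumsum_bwd_ref(dz: torch.Tensor) -> torch.Tensor:
    # gradient of an inclusive cumsum: a reversed (suffix) cumsum of the upstream grad
    return dz.float().flip(2).cumsum(2).flip(2).to(dz.dtype)
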
@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 16}, num_warps=4),
        triton.Config({'BT': 16}, num_warps=8),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=4),
        triton.Config({'BT': 64}, num_warps=8),
    ],
    key=['S']
)
@triton.jit
def chunk_cumsum_fwd_kernel(
    s,
    z,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    o_i = tl.arange(0, BT)
    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    b_z = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT)):
        p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
        b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
        tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))

        if i_t >= 0:
            b_z += tl.sum(b_s, 0)


@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 16}, num_warps=4),
        triton.Config({'BT': 16}, num_warps=8),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=4),
        triton.Config({'BT': 64}, num_warps=8),
    ],
    key=['S']
)
@triton.jit
def chunk_cumsum_bwd_kernel(
    ds,
    dz,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    o_i = tl.arange(0, BT)
    m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)

    b_ds = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
        p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_dz = tl.make_block_ptr(dz + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_dz = tl.load(p_dz, boundary_check=(0, 1)).to(tl.float32)
        b_c = b_ds[None, :] + tl.dot(m_s, b_dz, allow_tf32=False)
        tl.store(p_ds, b_c.to(p_ds.dtype.element_ty), boundary_check=(0, 1))

        if i_t >= 0:
            b_ds += tl.sum(b_dz, 0)


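The chunked variants above replace the per-step recurrence with one triangular matmul per BT-row tile plus a per-tile carry (the column sums of all earlier tiles). A hedged PyTorch sketch of that decomposition, useful for checking chunk_cumsum_fwd_kernel against a plain cumsum (editorial sketch; BT here is an illustrative block size):

import torch

def chunk_cumsum_fwd_ref(s: torch.Tensor, BT: int = 32) -> torch.Tensor:
    B, H, T, S = s.shape
    z = torch.empty_like(s, dtype=torch.float)
    carry = s.new_zeros(B, H, S, dtype=torch.float)
    for t0 in range(0, T, BT):
        b_s = s[:, :, t0:t0 + BT].float()
        n = b_s.shape[2]
        m_s = torch.ones(n, n, device=s.device).tril()            # lower-triangular mask (m_s)
        z[:, :, t0:t0 + BT] = carry[:, :, None] + torch.einsum('ij,bhjs->bhis', m_s, b_s)
        carry = carry + b_s.sum(2)                                 # carry the tile's column sums
    return z                                                       # equals s.float().cumsum(2)
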
@contiguous
def chunk_cumsum_fwd(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    B, H, T, S = s.shape
    BS = 32

    dtype = dtype or s.dtype
    grid = (triton.cdiv(S, BS), B * H)
    z = torch.empty_like(s, dtype=dtype)
    chunk_cumsum_fwd_kernel[grid](
        s, z,
        s.stride(1), s.stride(2), s.stride(3),
        T=T, S=S, BS=BS
    )
    return z


@contiguous
def chunk_cumsum_bwd(
    dz: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    B, H, T, S = dz.shape
    BS = 32

    dtype = dtype or dz.dtype
    grid = (triton.cdiv(S, BS), B * H)
    ds = torch.empty_like(dz, dtype=dtype)
    chunk_cumsum_bwd_kernel[grid](
        ds, dz,
        ds.stride(1), ds.stride(2), ds.stride(3),
        T=T, S=S, BS=BS
    )
    return ds


class CumsumFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, s, dtype):
        z = chunk_cumsum_fwd(s, dtype)
        ctx.dtype = dtype
        return z

    @staticmethod
    def backward(ctx, dz):
        ds = chunk_cumsum_bwd(dz, ctx.dtype)
        return ds, None


def cumsum(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    return CumsumFunction.apply(s, dtype)


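cumsum wraps the two chunked kernels in an autograd.Function so gradients flow through chunk_cumsum_bwd. A hedged usage sketch (the shapes and the CUDA device are illustrative assumptions, not part of the vendored file):

import torch

s = torch.randn(2, 4, 128, 64, device='cuda', requires_grad=True)
z = cumsum(s)                                    # same values as s.cumsum(2)
assert torch.allclose(z, s.cumsum(2), atol=1e-4)
z.sum().backward()                               # backward dispatches to chunk_cumsum_bwd
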
@triton.autotune(
    configs=[
        triton.Config({'BS': 32}, num_warps=2),
        triton.Config({'BS': 32}, num_warps=4),
        triton.Config({'BS': 32}, num_warps=8),
        triton.Config({'BS': 64}, num_warps=2),
        triton.Config({'BS': 64}, num_warps=4),
        triton.Config({'BS': 64}, num_warps=8),
        triton.Config({'BS': 128}, num_warps=2),
        triton.Config({'BS': 128}, num_warps=4),
        triton.Config({'BS': 128}, num_warps=8),
    ],
    key=['S']
)
@triton.jit
def recurrent_reversed_cumsum_fwd_kernel(
    s,
    z,
    s_s_h,
    s_s_t,
    T: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)

    o_s = i_s * BS + tl.arange(0, BS)
    mask = o_s < S

    b_z = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(T - 1, -1, -1):
        # [BS]
        b_s = tl.load(s + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32)
        b_z = b_z + b_s

        tl.store(z + i_bh * s_s_h + i_t * s_s_t + o_s, b_z.to(s.dtype.element_ty), mask=mask)


@triton.autotune(
    configs=[
        triton.Config({'BS': 32}, num_warps=2),
        triton.Config({'BS': 32}, num_warps=4),
        triton.Config({'BS': 32}, num_warps=8),
        triton.Config({'BS': 64}, num_warps=2),
        triton.Config({'BS': 64}, num_warps=4),
        triton.Config({'BS': 64}, num_warps=8),
        triton.Config({'BS': 128}, num_warps=2),
        triton.Config({'BS': 128}, num_warps=4),
        triton.Config({'BS': 128}, num_warps=8),
    ],
    key=['S']
)
@triton.jit
def recurrent_reversed_cumsum_bwd_kernel(
    ds,
    dz,
    s_s_h,
    s_s_t,
    T: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)

    o_s = i_s * BS + tl.arange(0, BS)
    mask = o_s < S

    b_ds = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(0, T):
        # [BS]
        b_dz = tl.load(dz + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32)
        b_ds = b_ds + b_dz

        tl.store(ds + i_bh * s_s_h + i_t * s_s_t + o_s, b_ds.to(ds.dtype.element_ty), mask=mask)


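The reversed recurrent kernels are the mirror image of the earlier ones: the forward loop runs from the last time step (a suffix sum) and the backward loop runs forward. This is equivalent to the flip-cumsum-flip idiom (editorial sketch, not part of the vendored file):

import torch

def recurrent_reversed_cumsum_fwd_ref(s: torch.Tensor) -> torch.Tensor:
    # suffix (right-to-left) sums over the time axis
    return s.float().flip(2).cumsum(2).flip(2).to(s.dtype)

def recurrent_reversed_cumsum_bwd_ref(dz: torch.Tensor) -> torch.Tensor:
    # gradient of a suffix sum is a plain left-to-right cumsum of the upstream grad
    return dz.float().cumsum(2).to(dz.dtype)
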
@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 16}, num_warps=4),
        triton.Config({'BT': 16}, num_warps=8),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=4),
        triton.Config({'BT': 64}, num_warps=8),
    ],
    key=['S']
)
@triton.jit
def chunk_reversed_cumsum_fwd_kernel(
    s,
    z,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    o_i = tl.arange(0, BT)
    m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)

    b_z = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
        p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
        b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
        tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))

        if i_t >= 0:
            b_z += tl.sum(b_s, 0)


@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 16}, num_warps=4),
        triton.Config({'BT': 16}, num_warps=8),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=4),
        triton.Config({'BT': 64}, num_warps=8),
    ],
    key=['S']
)
@triton.jit
def chunk_reversed_cumsum_bwd_kernel(
    ds,
    dz,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    o_i = tl.arange(0, BT)
    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    b_ds = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT)):
        p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_dz = tl.make_block_ptr(dz + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_dz = tl.load(p_dz, boundary_check=(0, 1)).to(tl.float32)
        b_c = b_ds[None, :] + tl.dot(m_s, b_dz, allow_tf32=False)
        tl.store(p_ds, b_c.to(p_ds.dtype.element_ty), boundary_check=(0, 1))

        if i_t >= 0:
            b_ds += tl.sum(b_dz, 0)


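The chunked reversed kernels visit tiles back-to-front and use an upper-triangular mask, so each row picks up its within-tile suffix sum plus the carried total of all later tiles. A hedged check of chunk_reversed_cumsum_fwd (defined just below) against the flip-based reference (shapes and the CUDA device are illustrative assumptions):

import torch

dz = torch.randn(2, 4, 128, 64, device='cuda')
ref = dz.float().flip(2).cumsum(2).flip(2)
out = chunk_reversed_cumsum_fwd(dz, dtype=torch.float)
assert torch.allclose(out, ref, atol=1e-4)
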
@contiguous
def chunk_reversed_cumsum_fwd(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    B, H, T, S = s.shape
    BS = 32

    dtype = dtype or s.dtype
    grid = (triton.cdiv(S, BS), B * H)
    z = torch.empty_like(s, dtype=dtype)
    chunk_reversed_cumsum_fwd_kernel[grid](
        s, z,
        s.stride(1), s.stride(2), s.stride(3),
        T=T, S=S, BS=BS
    )
    return z


@contiguous
def chunk_reversed_cumsum_bwd(
    dz: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    B, H, T, S = dz.shape
    BS = 32

    dtype = dtype or dz.dtype
    grid = (triton.cdiv(S, BS), B * H)
    ds = torch.empty_like(dz, dtype=dtype)
    chunk_reversed_cumsum_bwd_kernel[grid](
        ds, dz,
        ds.stride(1), ds.stride(2), ds.stride(3),
        T=T, S=S, BS=BS
    )
    return ds


class ReversedCumsumFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, s, dtype):
        z = chunk_reversed_cumsum_fwd(s, dtype)
        ctx.dtype = dtype
        return z

    @staticmethod
    def backward(ctx, dz):
        ds = chunk_reversed_cumsum_bwd(dz, ctx.dtype)
        return ds, None


def reversed_cumsum(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    return ReversedCumsumFunction.apply(s, dtype)
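A hedged usage sketch for reversed_cumsum, mirroring the cumsum example above (shapes and the CUDA device are illustrative assumptions, not part of the vendored file):

import torch

s = torch.randn(2, 4, 128, 64, device='cuda', requires_grad=True)
z = reversed_cumsum(s)                           # suffix sums over the time axis
assert torch.allclose(z, s.flip(2).cumsum(2).flip(2), atol=1e-4)
z.sum().backward()                               # backward dispatches to chunk_reversed_cumsum_bwd
# d(sum of suffix sums) / d s[t] = t + 1, i.e. a plain cumsum of the all-ones upstream grad
expected = torch.arange(1, s.shape[2] + 1, device=s.device, dtype=s.dtype)
assert torch.allclose(s.grad, expected.view(1, 1, -1, 1).expand_as(s))
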