This commit is contained in:
11
finetune/lora/v6/fla/ops/gla/__init__.py
vendored
Normal file
11
finetune/lora/v6/fla/ops/gla/__init__.py
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .chunk import chunk_gla
|
||||
from .chunk_fuse import fused_chunk_gla
|
||||
from .recurrent_fuse import fused_recurrent_gla
|
||||
|
||||
__all__ = [
|
||||
'chunk_gla',
|
||||
'fused_chunk_gla',
|
||||
'fused_recurrent_gla'
|
||||
]
|
||||
734
finetune/lora/v6/fla/ops/gla/chunk.py
vendored
Normal file
734
finetune/lora/v6/fla/ops/gla/chunk.py
vendored
Normal file
@@ -0,0 +1,734 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.ops.utils import chunk_reversed_cumsum_fwd
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BS': 16}, num_warps=2),
|
||||
triton.Config({'BS': 16}, num_warps=4),
|
||||
triton.Config({'BS': 16}, num_warps=8),
|
||||
triton.Config({'BS': 32}, num_warps=2),
|
||||
triton.Config({'BS': 32}, num_warps=4),
|
||||
triton.Config({'BS': 32}, num_warps=8),
|
||||
triton.Config({'BS': 64}, num_warps=2),
|
||||
triton.Config({'BS': 64}, num_warps=4),
|
||||
triton.Config({'BS': 64}, num_warps=8),
|
||||
],
|
||||
key=['S']
|
||||
)
|
||||
@triton.jit
|
||||
def chunk_gla_fwd_kernel_cum(
|
||||
s,
|
||||
o,
|
||||
s_s_h,
|
||||
s_s_t,
|
||||
s_s_d,
|
||||
T: tl.constexpr,
|
||||
S: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BS: tl.constexpr
|
||||
):
|
||||
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
||||
|
||||
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
# [BT, BS]
|
||||
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = tl.dot(m_s, b_s, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_fwd_kernel_h(
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
h,
|
||||
h0,
|
||||
ht,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
for i_t in range(NT):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
if i_t < NT - 1:
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
else:
|
||||
b_gn = tl.min(b_g, axis=1)
|
||||
b_h *= tl.exp(b_gn)[:, None]
|
||||
b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
|
||||
b_h += tl.dot(b_k, b_v, allow_tf32=False)
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_fwd_kernel_intra(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
A,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
NC: tl.constexpr
|
||||
):
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
|
||||
n_bh = tl.num_programs(2)
|
||||
|
||||
if i_i > i_j:
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype)
|
||||
# [BK, BC]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)
|
||||
# [BC, BC]
|
||||
b_A = tl.dot(b_qg, b_kg, allow_tf32=False)
|
||||
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
|
||||
elif i_i == i_j:
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
|
||||
o_i = tl.arange(0, BC)
|
||||
o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
|
||||
m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
|
||||
for j in range(0, BC):
|
||||
# [BK,]
|
||||
b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
|
||||
b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
|
||||
# [BC,]
|
||||
b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1)
|
||||
b_A = tl.where(o_i >= j, b_A, 0.)
|
||||
tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)
|
||||
|
||||
p_k = tl.advance(p_k, (K,))
|
||||
p_gk = tl.advance(p_gk, (K,))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_fwd_kernel_inter(
|
||||
q,
|
||||
v,
|
||||
g,
|
||||
h,
|
||||
o,
|
||||
A,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BT, BK]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
# [BT, BK]
|
||||
b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# works but dkw, owing to divine benevolence
|
||||
# [BT, BV]
|
||||
if i_k >= 0:
|
||||
b_o += tl.dot(b_qg, b_h, allow_tf32=False)
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
b_o += tl.dot(b_A, b_v, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_bwd_kernel_dh(
|
||||
q,
|
||||
g,
|
||||
do,
|
||||
dh,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for i_t in range(NT - 1, -1, -1):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
# [BK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BT, BV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
|
||||
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BK, BV]
|
||||
b_dh *= tl.exp(b_gn)[:, None]
|
||||
# [BK, BT]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
b_q = (b_q * tl.exp(b_g)).to(b_q.dtype)
|
||||
|
||||
# [BK, BV]
|
||||
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_bwd_kernel_inter(
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
A,
|
||||
do,
|
||||
dh,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
dA,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
n_bh = tl.num_programs(2)
|
||||
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
|
||||
|
||||
# [BT, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk)
|
||||
b_k = (b_k * b_gn).to(b_k.dtype)
|
||||
# [BT, BT]
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BV, BK]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh = tl.load(p_dh, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BV]
|
||||
b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
|
||||
if i_k == 0:
|
||||
b_dv += tl.dot(b_A, b_do, allow_tf32=False)
|
||||
b_do = (b_do * scale).to(b_do.dtype)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_do, b_h, allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
|
||||
b_dq = b_dq * tl.exp(b_gk)
|
||||
b_dk = b_dk * b_gn
|
||||
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = o_i[:, None] >= o_i[None, :]
|
||||
# [BT, BT]
|
||||
b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
|
||||
if i_k == 0:
|
||||
tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_gla_bwd_kernel_intra(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
dA,
|
||||
dq,
|
||||
dk,
|
||||
dg,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
NC: tl.constexpr
|
||||
):
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_t, i_i = i_c // NC, i_c % NC
|
||||
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BC, BK]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
b_dq = tl.zeros([BC, BK], dtype=tl.float32)
|
||||
for i_j in range(0, i_i):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
|
||||
# [BC, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
|
||||
# [BC, BC]
|
||||
b_dA = tl.load(p_dA, boundary_check=(0, 1))
|
||||
# [BC, BK]
|
||||
b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
|
||||
b_dq *= tl.exp(b_g - b_gn[None, :])
|
||||
|
||||
o_i = tl.arange(0, BC)
|
||||
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
|
||||
m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
|
||||
for j in range(0, BC):
|
||||
p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
|
||||
p_gkj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BC,]
|
||||
b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
|
||||
# [BK,]
|
||||
b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
|
||||
b_gkj = tl.load(p_gkj, boundary_check=(0,)).to(tl.float32)
|
||||
# [BC, BK]
|
||||
m_i = o_i[:, None] >= j
|
||||
# [BC, BK]
|
||||
b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_g - b_gkj[None, :]), 0.)
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
|
||||
b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
tl.debug_barrier()
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BC, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_dk = tl.zeros([BC, BK], dtype=tl.float32)
|
||||
for i_j in range(i_i + 1, NC):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
b_qg = (b_q * tl.exp(b_g - b_gn[None, :])).to(b_q.dtype)
|
||||
# [BC, BC]
|
||||
b_dA = tl.load(p_dA, boundary_check=(0, 1))
|
||||
# [BC, BK]
|
||||
b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False)
|
||||
b_dk *= tl.exp(b_gn[None, :] - b_gk)
|
||||
|
||||
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
|
||||
for j in range(0, BC):
|
||||
p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
|
||||
p_gqj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BC,]
|
||||
b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
|
||||
# [BK,]
|
||||
b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
|
||||
b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32)
|
||||
# [BC, BK]
|
||||
m_i = o_i[:, None] <= j
|
||||
b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.)
|
||||
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dg = tl.make_block_ptr(dg + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1))
|
||||
b_dg = b_q * b_dq - b_k * b_dk
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class ChunkGLAFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level):
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT, BC = 64, 16
|
||||
BK = min(64, triton.next_power_of_2(K))
|
||||
BV = min(64, triton.next_power_of_2(V))
|
||||
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
|
||||
NK = triton.cdiv(K, BK)
|
||||
NV = triton.cdiv(V, BV)
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
num_stages = 1
|
||||
|
||||
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NV, NK, B * H)
|
||||
chunk_gla_fwd_kernel_h[grid](
|
||||
k, v, g, h, h0, ht,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=h0 is not None,
|
||||
STORE_FINAL_STATE=ht is not None,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return h
|
||||
|
||||
final_state = None
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(B, H, K, V, dtype=torch.float)
|
||||
|
||||
g_org, g = g, torch.empty_like(g, dtype=torch.float)
|
||||
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
|
||||
# keep cummulative normalizer in fp32
|
||||
# this kernel is equivalent to
|
||||
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
|
||||
chunk_gla_fwd_kernel_cum[grid](
|
||||
g_org, g,
|
||||
g.stride(1), g.stride(2), g.stride(3),
|
||||
T=T, S=K, BT=BT
|
||||
)
|
||||
h = fwd_inner(
|
||||
q=q, k=k, v=v, g=g,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
h0=initial_state if initial_state is not None else None,
|
||||
ht=final_state if final_state is not None else None
|
||||
)
|
||||
A = q.new_zeros(NK, B, H, T, BT)
|
||||
grid = (NK, NT * NC * NC, B * H)
|
||||
chunk_gla_fwd_kernel_intra[grid](
|
||||
q, k, g, A,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
scale,
|
||||
T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
A = A.sum(0, dtype=A.dtype)
|
||||
o = torch.empty_like(v)
|
||||
grid = (NV, NT, B * H)
|
||||
chunk_gla_fwd_kernel_inter[grid](
|
||||
q, v, g, h, o, A,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
if checkpoint_level >= 1:
|
||||
del g
|
||||
g = g_org
|
||||
if checkpoint_level > 1:
|
||||
del h
|
||||
h, initial_state = None, None
|
||||
|
||||
ctx.save_for_backward(q, k, v, g, h, initial_state, A)
|
||||
ctx.BT = BT
|
||||
ctx.scale = scale
|
||||
ctx.checkpoint_level = checkpoint_level
|
||||
return o, final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, do, dht=None):
|
||||
q, k, v, g, h, initial_state, A = ctx.saved_tensors
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT, BC = ctx.BT, 16
|
||||
BK = min(64, triton.next_power_of_2(K))
|
||||
BV = min(64, triton.next_power_of_2(V))
|
||||
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
|
||||
NK = triton.cdiv(K, BK)
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
num_stages = 1
|
||||
|
||||
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NV, NK, B * H)
|
||||
chunk_gla_fwd_kernel_h[grid](
|
||||
k, v, g, h, h0, ht,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=h0 is not None,
|
||||
STORE_FINAL_STATE=ht is not None,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return h
|
||||
|
||||
def bwd_inner(q, g, do, B, H, T, K, V, BT, BK, BV, NT, scale):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
dh = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_gla_bwd_kernel_dh[grid](
|
||||
q, g, do, dh,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
do.stride(1), do.stride(2), do.stride(3),
|
||||
dh.stride(1), dh.stride(2), dh.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return dh
|
||||
|
||||
if ctx.checkpoint_level >= 1:
|
||||
# save the original g and compute its fp32 cumsum during the backward pass for memory consideration
|
||||
g_org, g = g, torch.zeros_like(g, dtype=torch.float)
|
||||
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
|
||||
# keep cummulative normalizer in fp32
|
||||
# this kernel is equivalent to
|
||||
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
|
||||
chunk_gla_fwd_kernel_cum[grid](
|
||||
g_org, g,
|
||||
g.stride(1), g.stride(2), g.stride(3),
|
||||
T=T, S=K, BT=BT
|
||||
)
|
||||
|
||||
# rerun the forward pass to get h if checkpoint_level >= 1
|
||||
if ctx.checkpoint_level > 1:
|
||||
h = fwd_inner(
|
||||
q=q, k=k, v=v, g=g,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
h0=initial_state if initial_state is not None else None,
|
||||
ht=None
|
||||
)
|
||||
|
||||
scale = ctx.scale
|
||||
dh = bwd_inner(
|
||||
q, g, do,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
scale=scale
|
||||
)
|
||||
dq = torch.empty_like(q, dtype=torch.float)
|
||||
dk = torch.empty_like(k, dtype=torch.float)
|
||||
dg = torch.empty_like(k, dtype=torch.float)
|
||||
dv = v.new_empty(NK, *v.shape)
|
||||
dA = q.new_zeros(B, H, T, BT)
|
||||
grid = (NK, NT, B * H)
|
||||
chunk_gla_bwd_kernel_inter[grid](
|
||||
k, v, h, g, A, do, dh, dq, dk, dv, dA,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
dv = dv.sum(0, dtype=dv.dtype)
|
||||
grid = (NK, NT * NC, B * H)
|
||||
chunk_gla_bwd_kernel_intra[grid](
|
||||
q, k, g, dA, dq, dk, dg,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
dq = dq.to(q.dtype)
|
||||
dk = dk.to(q.dtype)
|
||||
# reversed cumsum, equivalent to:
|
||||
#
|
||||
# def reversed_cumsum(x, dim=-1):
|
||||
# c = x.cumsum(dim)
|
||||
# return x + c.index_select(dim, x.new_tensor([c.shape[dim]-1], dtype=torch.long)) - c
|
||||
dg = chunk_reversed_cumsum_fwd(dg).to(k.dtype)
|
||||
return dq, dk, dv, dg, None, None, None, None
|
||||
|
||||
|
||||
def chunk_gla(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
scale: Optional[int] = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
checkpoint_level: Optional[int] = 2
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `(B, H, T, K)`
|
||||
k (torch.Tensor):
|
||||
keys of shape `(B, H, T, K)`
|
||||
v (torch.Tensor):
|
||||
values of shape `(B, H, T, V)`
|
||||
g (torch.Tensor):
|
||||
Forget gates of shape `(B, H, T, K)` applied to keys.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the GLA attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `(B, H, K, V)`. Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
|
||||
checkpoint_level (Optional[int]):
|
||||
Checkpointing level; higher values will save more memories and do more recomputations during backward.
|
||||
Default: `0`:
|
||||
- Level `0`: no memory saved, no recomputation.
|
||||
- Level `1`: recompute the fp32 cumulative values during backward.
|
||||
- Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
|
||||
"""
|
||||
assert checkpoint_level in [0, 1, 2]
|
||||
if scale is None:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level)
|
||||
return o, final_state
|
||||
548
finetune/lora/v6/fla/ops/gla/chunk_fuse.py
vendored
Normal file
548
finetune/lora/v6/fla/ops/gla/chunk_fuse.py
vendored
Normal file
@@ -0,0 +1,548 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023, Songlin Yang
|
||||
# Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635
|
||||
# on-the-fly computation without materializing hidden statets into HBMs
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.ops.gla.chunk_util import (bwd_decay_global_cumsum, fwd_decay_cumsum,
|
||||
prepare_qg_kg)
|
||||
from fla.utils import contiguous
|
||||
|
||||
inv_ln2 = 1.44269504
|
||||
|
||||
@triton.jit
|
||||
def fused_chunk_gla_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
g, # cumulative sum of log decay [B, H, L, D_head_K]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
CHECK: tl.constexpr
|
||||
):
|
||||
# indices
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
# make block pointers
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
|
||||
p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
|
||||
for i in range(0, tl.cdiv(T, BT)):
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)
|
||||
if CHECK and i == 0:
|
||||
b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
|
||||
b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
|
||||
else:
|
||||
b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
|
||||
b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False)
|
||||
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
p_q = tl.advance(p_q, (BT, 0))
|
||||
p_k = tl.advance(p_k, (0, BT))
|
||||
p_v = tl.advance(p_v, (BT, 0))
|
||||
p_o = tl.advance(p_o, (BT, 0))
|
||||
p_db += BT * DK
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
|
||||
@triton.jit
|
||||
def fused_chunk_gla_bwd_kernel(
|
||||
q, k, v, g,
|
||||
do, # gradient of output [B, H, L, D_head_V]
|
||||
dq, # gradient of query [NV, B, H, L, D_head_K]
|
||||
dk, # gradient of key [NV, B, H, L, D_head_K]
|
||||
dv, # gradient of value [NK, B, H, L, D_head_V]
|
||||
|
||||
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
CHECK: tl.constexpr
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
# [BV, BK]
|
||||
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
|
||||
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
for i in range(0, tl.cdiv(T, BT)):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
# [BT, DK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# b_g = tl.load(p_g, boundary_check=(0, 1)) * inv_ln2
|
||||
d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32)
|
||||
|
||||
# [DV, BT]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, DV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [DV, DK]
|
||||
if CHECK and i == 0:
|
||||
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
|
||||
b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
|
||||
else:
|
||||
b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
|
||||
b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False)
|
||||
b_dq *= scale
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
# sync threads
|
||||
b_h = None
|
||||
tl.debug_barrier()
|
||||
# [BK, BV]
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
# cum = tl.zeros([BK], dtype=tl.float32)
|
||||
for i in range(1, tl.cdiv(T, BT) + 1):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK)
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK),
|
||||
(s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV),
|
||||
(s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
# [DK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BT, DK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, DV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32)
|
||||
|
||||
# inter-chunk
|
||||
# [DK, DV]
|
||||
if CHECK and i == 1:
|
||||
b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
|
||||
b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
|
||||
b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
|
||||
else:
|
||||
b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False))
|
||||
b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False)
|
||||
b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False)
|
||||
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fwd_inner_chunk(
|
||||
q, k, g, A,
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
):
|
||||
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
o_i = tl.arange(0, BT)
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
|
||||
p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
|
||||
p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
|
||||
|
||||
for i in range(BT):
|
||||
_q = tl.load(p_q, mask=mask, other=0) * scale
|
||||
gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
|
||||
s = _q[None, :] * b_k * tl.math.exp2(gq[None, :] - b_g)
|
||||
score = tl.sum(s, axis=1)
|
||||
score = tl.where(o_i <= i, score, 0)
|
||||
tl.store(p_A, score.to(p_A.dtype.element_ty))
|
||||
p_q += DK
|
||||
p_gq += DK
|
||||
p_A += BT
|
||||
|
||||
|
||||
@triton.jit
|
||||
def bwd_inner_chunk(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
dA,
|
||||
dq,
|
||||
dk,
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
# clamp_min, # minimum log value of the gate for numerical stability. default: -5
|
||||
BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
):
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
o_i = tl.arange(0, BT)
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
|
||||
p_dq = dq + (i_bh) * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
|
||||
p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK)
|
||||
p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT)
|
||||
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
|
||||
for i in range(BT):
|
||||
_q = tl.load(p_q, mask=mask, other=0)
|
||||
gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32)
|
||||
score = tl.math.exp2(gq[None, :] - b_g)
|
||||
score = tl.where(o_i[:, None] <= i, score, 0)
|
||||
_dA = tl.load(p_dA)
|
||||
_dA = tl.where(o_i <= i, _dA, 0)
|
||||
b_dk += (_dA[:, None] * score * _q[None, :])
|
||||
b_dq = tl.sum(_dA[:, None] * score * b_k, axis=0)
|
||||
tl.store(p_dq, b_dq, mask=mask)
|
||||
p_q += DK
|
||||
p_dq += DK
|
||||
p_gq += DK
|
||||
p_dA += BT
|
||||
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class FusedChunkGLAFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
|
||||
ctx.g_dtype = g.dtype
|
||||
g_original = g
|
||||
# cumulative decay should be in float32, otherwise the err will be accumulated and amplified.
|
||||
g = torch.empty_like(g, dtype=torch.float32)
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
ctx.scale = scale
|
||||
|
||||
# inter-chunk
|
||||
BT = 16 # chunk_size
|
||||
BK, BV = min(d_head_qk, 64), min(d_head_v, 64)
|
||||
num_stages = 1
|
||||
num_warps = 2
|
||||
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
q_g = torch.empty_like(q)
|
||||
k_g = torch.empty_like(k)
|
||||
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
|
||||
fwd_decay_cumsum[grid](
|
||||
g_original,
|
||||
g,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
|
||||
)
|
||||
prepare_qg_kg[grid](
|
||||
q, k, g, q_g, k_g,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
|
||||
)
|
||||
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float, requires_grad=False)
|
||||
else:
|
||||
final_state = None
|
||||
# the bug still exists even for Triton 2.2 on H100 GPUs
|
||||
# so we always enable initial checks
|
||||
CHECK = True
|
||||
if version.parse(triton.__version__) < version.parse('2.2.0'):
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Triton<2.2.0 detected for running this kernel, "
|
||||
"which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
|
||||
"that lead to significant precision loss. "
|
||||
"We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. "
|
||||
"For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
|
||||
)
|
||||
CHECK = True
|
||||
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
fused_chunk_gla_fwd_kernel[grid](
|
||||
q_g, k_g, v, g, o, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=output_final_state,
|
||||
CHECK=CHECK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
o = o.sum(0)
|
||||
|
||||
# intra-chunk
|
||||
chunk_size = 16
|
||||
num_chunk = seq_len // chunk_size
|
||||
v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
|
||||
BK = min(d_head_qk, 64)
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
A = q.new_empty(NK, batch_size, n_heads, triton.cdiv(seq_len, BT), BT, BT)
|
||||
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
|
||||
fwd_inner_chunk[grid](
|
||||
q, k, g, A,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale, BT=BT, BK=BK, DK=d_head_qk, num_stages=3,
|
||||
num_warps=4
|
||||
)
|
||||
A = A.sum(0)
|
||||
o2 = A @ v2
|
||||
o2 = rearrange(o2, 'b h n c d -> b h (n c) d')
|
||||
# combine inner and inter
|
||||
o.add_(o2)
|
||||
ctx.save_for_backward(q, k, v, g_original, A, initial_state)
|
||||
ctx.CHECK = CHECK
|
||||
return o.to(v), final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_bwd
|
||||
def backward(ctx, do, d_final_state=None):
|
||||
q, k, v, g_origin, A, initial_state = ctx.saved_tensors
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
scale = ctx.scale
|
||||
|
||||
# recomputation
|
||||
# inter-chunk
|
||||
BT = 16 # chunk_size
|
||||
g = torch.empty_like(g_origin, dtype=torch.float32)
|
||||
BK, BV = min(d_head_qk, 64), min(d_head_v, 64)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
q_g = torch.empty_like(q)
|
||||
k_g = torch.empty_like(k)
|
||||
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
|
||||
fwd_decay_cumsum[grid](
|
||||
g_origin,
|
||||
g,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
|
||||
)
|
||||
prepare_qg_kg[grid](
|
||||
q, k, g, q_g, k_g,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, BK=BK, DK=d_head_qk, num_warps=1
|
||||
)
|
||||
|
||||
# inter-chunk
|
||||
BT = 16
|
||||
BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 2
|
||||
dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
|
||||
dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
|
||||
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
|
||||
fused_chunk_gla_bwd_kernel[grid](
|
||||
q_g, k_g, v, g, do, dq, dk, dv, initial_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
# clamp_min=-3,
|
||||
BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
CHECK=ctx.CHECK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
dq = dq.sum(0)
|
||||
dk = dk.sum(0)
|
||||
dv = dv.sum(0)
|
||||
|
||||
# intra chunk
|
||||
num_chunk = seq_len // BT
|
||||
v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
|
||||
do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk)
|
||||
dA2 = (do2 @ v2.transpose(-2, -1)) * scale
|
||||
dv2 = A.transpose(-1, -2) @ do2
|
||||
dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk)
|
||||
|
||||
BK = min(triton.next_power_of_2(d_head_qk), 16)
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
dk2 = torch.empty_like(k)
|
||||
dq2 = torch.empty_like(q)
|
||||
|
||||
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
|
||||
bwd_inner_chunk[grid](
|
||||
q, k, g,
|
||||
dA2, dq2, dk2,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, BK=BK,
|
||||
num_warps=1,
|
||||
num_stages=3
|
||||
)
|
||||
|
||||
BK = min(triton.next_power_of_2(d_head_qk), 32)
|
||||
NK = triton.cdiv(d_head_qk, BK)
|
||||
dg = torch.empty_like(g, dtype=torch.float32)
|
||||
grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads)
|
||||
bwd_decay_global_cumsum[grid](
|
||||
dq2, dq, dk2, dk, q, k, g, dg,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
BT=BT, DK=d_head_qk, BK=BK,
|
||||
num_warps=1,
|
||||
num_stages=1
|
||||
)
|
||||
dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)
|
||||
|
||||
def rev_cumsum_exclusive(x):
|
||||
cumsum_x = x.cumsum(-2)
|
||||
rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x
|
||||
return rev_cumsum_x
|
||||
|
||||
rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])
|
||||
dg.add_(rev_cumsum_dg.unsqueeze(-2))
|
||||
dv.add_(dv2)
|
||||
dg = rearrange(dg, 'b h n c d -> b h (n c) d')
|
||||
|
||||
return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None
|
||||
|
||||
|
||||
def pad(x, chunk_size=16):
|
||||
seq_len = x.shape[-2]
|
||||
padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size
|
||||
if x.shape[-2] % chunk_size != 0:
|
||||
x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def ceildiv(a, b):
|
||||
return -(a // -b)
|
||||
|
||||
|
||||
def fused_chunk_gla(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
scale: int = -1,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if scale == -1:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
seq_len = q.shape[-2]
|
||||
q, k, v, g = map(lambda x: pad(x), [q, k, v, g])
|
||||
o, final_state = FusedChunkGLAFunction.apply(
|
||||
q, k, v, g, scale, initial_state, output_final_state)
|
||||
o = o[..., :seq_len, :]
|
||||
return o, final_state
|
||||
138
finetune/lora/v6/fla/ops/gla/chunk_util.py
vendored
Normal file
138
finetune/lora/v6/fla/ops/gla/chunk_util.py
vendored
Normal file
@@ -0,0 +1,138 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
inv_ln2 = 1.44269504
|
||||
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fwd_decay_cumsum(
|
||||
g,
|
||||
g_o,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
B,
|
||||
H,
|
||||
T,
|
||||
scale,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
DK: tl.constexpr
|
||||
):
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
cum_decay = tl.zeros([BK], dtype=tl.float32)
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
|
||||
for i in range(BT):
|
||||
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
|
||||
cum_decay += _g * inv_ln2
|
||||
tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)
|
||||
p_g += DK
|
||||
p_go += DK
|
||||
|
||||
@triton.jit
|
||||
def prepare_qg_kg(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
qg,
|
||||
kg,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
B,
|
||||
H,
|
||||
T,
|
||||
scale,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
DK: tl.constexpr
|
||||
):
|
||||
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)
|
||||
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
|
||||
last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))
|
||||
|
||||
for i in range(BT):
|
||||
_q = tl.load(p_q, mask=mask, other=0)
|
||||
_k = tl.load(p_k, mask=mask, other=0)
|
||||
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
|
||||
_q *= tl.math.exp2(_g) * scale
|
||||
_k *= tl.math.exp2(last_decay - _g)
|
||||
tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)
|
||||
tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)
|
||||
p_q += DK
|
||||
p_g += DK
|
||||
p_k += DK
|
||||
p_kg += DK
|
||||
p_qg += DK
|
||||
|
||||
|
||||
@triton.jit
|
||||
def bwd_decay_global_cumsum(
|
||||
dq_inner,
|
||||
dq_inter,
|
||||
dk_inner,
|
||||
dk_inter,
|
||||
q, k, g, dg,
|
||||
s_qk_h,
|
||||
s_qk_t,
|
||||
s_qk_d,
|
||||
B,
|
||||
H,
|
||||
T,
|
||||
scale,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
DK: tl.constexpr
|
||||
):
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK
|
||||
cum_grad_dg = tl.zeros([BK], dtype=tl.float32)
|
||||
mask = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
last_g = tl.zeros([BK], dtype=tl.float32)
|
||||
for j in range(BT-1, -1, -1):
|
||||
_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
|
||||
if j == (BT-1):
|
||||
last_g = _g
|
||||
_dq1 = tl.load(p_dq_inner, mask=mask, other=0)
|
||||
_dq2 = tl.load(p_dq_inter, mask=mask, other=0)
|
||||
_dq2 *= tl.math.exp2(_g)
|
||||
_dq = _dq1 + _dq2
|
||||
tl.store(p_dq_inter, _dq, mask=mask)
|
||||
_dk1 = tl.load(p_dk_inner, mask=mask, other=0)
|
||||
_dk2 = tl.load(p_dk_inter, mask=mask, other=0)
|
||||
_dk2 *= tl.math.exp2(last_g - _g)
|
||||
_dk = _dk1 + _dk2
|
||||
tl.store(p_dk_inter, _dk, mask=mask)
|
||||
_q = tl.load(p_q, mask=mask, other=0)
|
||||
_k = tl.load(p_k, mask=mask, other=0)
|
||||
_dg = _dq * _q - _dk * _k
|
||||
cum_grad_dg += _dg
|
||||
tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)
|
||||
p_g -= DK
|
||||
p_k -= DK
|
||||
p_q -= DK
|
||||
p_dq_inner -= DK
|
||||
p_dk_inner -= DK
|
||||
p_dq_inter -= DK
|
||||
p_dk_inter -= DK
|
||||
p_dg -= DK
|
||||
|
||||
116
finetune/lora/v6/fla/ops/gla/naive.py
vendored
Normal file
116
finetune/lora/v6/fla/ops/gla/naive.py
vendored
Normal file
@@ -0,0 +1,116 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fla.ops.gla.recurrent_fuse import fused_recurrent_gla
|
||||
|
||||
|
||||
def ceildiv(a, b):
|
||||
return -(a // -b)
|
||||
|
||||
|
||||
def naive_recurrent_gla(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
gk,
|
||||
initial_state=None,
|
||||
output_final_state=False,
|
||||
causal=True
|
||||
):
|
||||
orig_dtype = q.dtype
|
||||
q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk))
|
||||
batch_size, n_heads, seq_len, d_head_k = q.shape
|
||||
_, _, _, d_head_v = v.shape
|
||||
h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
|
||||
o = torch.zeros_like(v)
|
||||
scale = d_head_k ** -0.5
|
||||
|
||||
if initial_state is not None:
|
||||
h += initial_state
|
||||
|
||||
for i in range(seq_len):
|
||||
q_i = q[:, :, i, :] * scale
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i, :]
|
||||
gk_i = gk[:, :, i].exp()
|
||||
kv_i = k_i[..., None] * v_i[..., None, :]
|
||||
h = h * gk_i[..., None] + kv_i
|
||||
o_i = (q_i[..., None] * h).sum(-2)
|
||||
o[:, :, i] = o_i
|
||||
|
||||
if causal:
|
||||
return o.to(orig_dtype), h
|
||||
else:
|
||||
o_reverse = torch.zeros_like(v)
|
||||
h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
|
||||
for i in range(seq_len-1, -1, -1):
|
||||
q_i = q[:, :, i, :] * scale
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i, :]
|
||||
gk_i = gk[:, :, i].exp()
|
||||
kv_i = k_i[..., None] * v_i[..., None, :]
|
||||
h = h * gk_i[..., None] + kv_i
|
||||
o_i = (q_i[..., None] * h).sum(-2)
|
||||
o_reverse[:, :, i] = o_i
|
||||
|
||||
return o, o_reverse
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
B = 4
|
||||
H = 4
|
||||
L = 512
|
||||
D = 128
|
||||
dtype = torch.float32
|
||||
q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
|
||||
k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
|
||||
v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
|
||||
g = F.logsigmoid(torch.rand(B, H, L, D)).cuda(
|
||||
).clamp_min(-1).to(torch.float32).requires_grad_(True)
|
||||
|
||||
do = torch.rand_like(v).cuda()
|
||||
do2 = torch.rand_like(v).cuda()
|
||||
intial_state = torch.rand(B, H, D, D).cuda()
|
||||
|
||||
ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False)
|
||||
|
||||
ref.backward(do, retain_graph=True)
|
||||
ref_rev.backward(do2, retain_graph=True)
|
||||
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dg, g.grad = g.grad.clone(), None
|
||||
|
||||
tri, tri_rev = fused_recurrent_gla(
|
||||
q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, causal=False)
|
||||
tri.backward(do, retain_graph=True)
|
||||
tri_rev.backward(do2, retain_graph=True)
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dg, g.grad = g.grad.clone(), None
|
||||
|
||||
assert ref.allclose(tri, 0, 1e-5), breakpoint()
|
||||
assert ref_rev.allclose(tri_rev, 0, 1e-5), breakpoint()
|
||||
assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
|
||||
assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
|
||||
assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
|
||||
assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
|
||||
|
||||
# tri = fused_chunk_gla(q, k, v, g)
|
||||
# tri.backward(do, retain_graph=True)
|
||||
# tri_dq, q.grad = q.grad.clone(), None
|
||||
# tri_dk, k.grad = k.grad.clone(), None
|
||||
# tri_dv, v.grad = v.grad.clone(), None
|
||||
# tri_dg, g.grad = g.grad.clone(), None
|
||||
|
||||
# assert ref.allclose(tri, 0, 1e-5), breakpoint()
|
||||
# assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint()
|
||||
# assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint()
|
||||
# assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint()
|
||||
# assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint()
|
||||
# breakpoint()
|
||||
print("Pass")
|
||||
404
finetune/lora/v6/fla/ops/gla/recurrent_fuse.py
vendored
Normal file
404
finetune/lora/v6/fla/ops/gla/recurrent_fuse.py
vendored
Normal file
@@ -0,0 +1,404 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.utils import contiguous
|
||||
|
||||
# on-the-fly computation without materializing hidden statets into HBMs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_gla_fwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_K]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
gk, # log gate [B, H, L, D_head_K]
|
||||
gv, # log gate [B, H, L, D_head_V]
|
||||
o, # output [B, H, L, D_head_V]
|
||||
# initial hidden state initialization [B, H, D_head_K, D_head_V]
|
||||
initial_state,
|
||||
final_state, # final hidden state [B, H, D_head_K, D_head_V]
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
|
||||
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
|
||||
USE_GK: tl.constexpr, # whether to use gk
|
||||
USE_GV: tl.constexpr, # whether to use gv
|
||||
):
|
||||
# indices
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
|
||||
if USE_GK:
|
||||
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
if USE_GV:
|
||||
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
|
||||
mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
|
||||
mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
|
||||
|
||||
h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
|
||||
mask_kv = mask_bk[None, :] & mask_bv[:, None]
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_init_s = initial_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[None, :]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
for _ in range(0, T):
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
if USE_GK:
|
||||
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
|
||||
h = h * _gk[None, :]
|
||||
if USE_GV:
|
||||
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
|
||||
h = h * _gv[:, None]
|
||||
h += _k[None, :] * _v[:, None]
|
||||
_o = h * _q[None, :]
|
||||
_o = tl.sum(_o, axis=1)
|
||||
tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
|
||||
p_q += -DK if REVERSE else DK
|
||||
p_k += -DK if REVERSE else DK
|
||||
p_o += -DV if REVERSE else DV
|
||||
p_v += -DV if REVERSE else DV
|
||||
if USE_GK:
|
||||
p_gk += -DK if REVERSE else DK
|
||||
if USE_GV:
|
||||
p_gv += -DV if REVERSE else DV
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_final_s = final_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[None, :]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
|
||||
|
||||
|
||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
|
||||
@triton.jit
|
||||
def fused_recurrent_gla_bwd_kernel(
|
||||
# B: batch_size, H: n_heads, T: seq_len, D: d_head
|
||||
# NV: number of split in the V dimension. NK: number of split in the K dimension
|
||||
q, # query [B, H, L, D_head_K]
|
||||
k, # key [B, H, L, D_head_V]
|
||||
v, # value [B, H, L, D_head_V]
|
||||
gk, # log gate [B, H, L, D_head_K] \alpha
|
||||
gv, # log gate [B, H, L, D_head_V] \bete
|
||||
|
||||
do, # gradient of output [B, H, L, D_head_V]
|
||||
dq, # gradient of query [NV, B, H, L, D_head_K]
|
||||
dk, # gradient of key [NV, B, H, L, D_head_K]
|
||||
dv, # gradient of value [NK, B, H, L, D_head_V]
|
||||
|
||||
# initial hidden state initialization [B, H, D_head_K, D_head_V]
|
||||
initial_state,
|
||||
|
||||
s_qk_h, # stride size: L * D_head_K
|
||||
s_qk_t, # stride size: D_head_K
|
||||
s_qk_d, # stride size: 1
|
||||
|
||||
s_vo_h, # stride size: L * D_head_V
|
||||
s_vo_t, # stride size: D_head_V
|
||||
s_vo_d, # stride size: 1
|
||||
|
||||
B, # batch_size
|
||||
H, # n_heads
|
||||
T, # seq_len
|
||||
scale, # D_head_K ** -0.5
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
DK: tl.constexpr, # D_head_K
|
||||
DV: tl.constexpr, # D_head_V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
|
||||
USE_GK: tl.constexpr, # whether to use gk
|
||||
USE_GV: tl.constexpr, # whether to use gv
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
p_do = do + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
if USE_GK:
|
||||
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0)
|
||||
if USE_GV:
|
||||
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0)
|
||||
mask_bk = i_k * BK + tl.arange(0, BK) < DK
|
||||
mask_bv = i_v * BV + tl.arange(0, BV) < DV
|
||||
mask_kv = mask_bk[:, None] & mask_bv[None, :]
|
||||
h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_init_s = initial_state + i_bh * DK * DV + \
|
||||
(i_k * BK + tl.arange(0, BK)[:, None]) * \
|
||||
DV + (i_v * BV + tl.arange(0, BV)[None, :])
|
||||
h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
for i in range(0, T):
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
if USE_GK:
|
||||
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
|
||||
h = h * _gk[:, None]
|
||||
if USE_GV:
|
||||
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
|
||||
h = h * _gv[None, :]
|
||||
h += _k[:, None] * _v[None, :]
|
||||
_d_q = h * _do[None, :]
|
||||
d_q = tl.sum(_d_q, axis=1) * scale
|
||||
tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
|
||||
|
||||
p_k += -DK if REVERSE else DK
|
||||
p_v += -DV if REVERSE else DV
|
||||
p_q += -DK if REVERSE else DK
|
||||
p_do += -DV if REVERSE else DV
|
||||
p_dq += -DK if REVERSE else DK
|
||||
if USE_GK:
|
||||
p_gk += -DK if REVERSE else DK
|
||||
if USE_GV:
|
||||
p_gv += -DV if REVERSE else DV
|
||||
|
||||
# sync threads
|
||||
tl.debug_barrier()
|
||||
|
||||
p_q = q + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
|
||||
p_k = k + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
|
||||
p_do = do + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
|
||||
p_v = v + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
|
||||
p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \
|
||||
BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
|
||||
p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \
|
||||
BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
|
||||
if USE_GK:
|
||||
p_gk = gk + i_bh * s_qk_h + i_k * BK + \
|
||||
tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0)
|
||||
if USE_GV:
|
||||
p_gv = gv + i_bh * s_vo_h + i_v * BV + \
|
||||
tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0)
|
||||
|
||||
d_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
|
||||
for _ in range(T):
|
||||
_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
d_h += _q[:, None] * _do[None, :]
|
||||
d_k = tl.sum(d_h * _v[None, :], axis=1)
|
||||
d_v = tl.sum(d_h * _k[:, None], axis=0)
|
||||
if USE_GK:
|
||||
_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32)
|
||||
d_h *= _gk[:, None]
|
||||
if USE_GV:
|
||||
_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32)
|
||||
d_h *= _gv[None, :]
|
||||
tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
|
||||
tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
|
||||
|
||||
p_do += DV if REVERSE else -DV
|
||||
p_q += DK if REVERSE else -DK
|
||||
p_k += DK if REVERSE else -DK
|
||||
p_v += DV if REVERSE else -DV
|
||||
p_dk += DK if REVERSE else -DK
|
||||
p_dv += DV if REVERSE else -DV
|
||||
if USE_GK:
|
||||
p_gk += DK if REVERSE else -DK
|
||||
if USE_GV:
|
||||
p_gv += DV if REVERSE else -DV
|
||||
|
||||
|
||||
class FusedRecurrentGLAFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False):
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
# default scale
|
||||
if scale is None:
|
||||
scale = d_head_qk ** -0.5
|
||||
if gk is not None:
|
||||
gk = gk.float().exp()
|
||||
if gv is not None:
|
||||
gv = gv.float().exp()
|
||||
|
||||
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
o = q.new_empty(NK, batch_size, n_heads, seq_len,
|
||||
d_head_v, dtype=torch.float32)
|
||||
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
|
||||
else:
|
||||
final_state = None
|
||||
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
fused_recurrent_gla_fwd_kernel[grid](
|
||||
q, k, v, gk, gv, o, initial_state, final_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=final_state is not None,
|
||||
USE_GK=gk is not None,
|
||||
USE_GV=gv is not None,
|
||||
REVERSE=reverse,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
o = o.sum(0)
|
||||
ctx.save_for_backward(q, k, v, gk, gv, initial_state, o)
|
||||
ctx.scale = scale
|
||||
ctx.reverse = reverse
|
||||
# we do not need the gradient of the final state from the next chunk
|
||||
# similiar to Trunctated BPTT
|
||||
if final_state is not None:
|
||||
final_state = final_state.detach()
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_bwd
|
||||
def backward(ctx, do, d_final_state=None):
|
||||
q, k, v, gk, gv, initial_state, o = ctx.saved_tensors
|
||||
batch_size, n_heads, seq_len, d_head_qk = q.shape
|
||||
d_head_v = v.shape[-1]
|
||||
scale = ctx.scale
|
||||
|
||||
BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
|
||||
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
dq = q.new_empty(NV, batch_size, n_heads, seq_len,
|
||||
d_head_qk, dtype=torch.float32)
|
||||
dk = q.new_empty(NV, batch_size, n_heads, seq_len,
|
||||
d_head_qk, dtype=torch.float32)
|
||||
dv = q.new_empty(NK, batch_size, n_heads, seq_len,
|
||||
d_head_v, dtype=torch.float32)
|
||||
grid = (NV, NK, batch_size * n_heads)
|
||||
|
||||
fused_recurrent_gla_bwd_kernel[grid](
|
||||
q, k, v, gk, gv, do, dq, dk, dv, initial_state,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
batch_size, n_heads, seq_len, scale,
|
||||
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
REVERSE=ctx.reverse,
|
||||
USE_GK=gk is not None,
|
||||
USE_GV=gv is not None
|
||||
)
|
||||
dq = dq.sum(0)
|
||||
dk = dk.sum(0)
|
||||
dv = dv.sum(0)
|
||||
if gk is not None:
|
||||
_dgk = dq * q.float() - dk * k.float()
|
||||
if ctx.reverse:
|
||||
dgk = _dgk.cumsum(-2)
|
||||
else:
|
||||
_dgk_cumsum = _dgk.cumsum(-2)
|
||||
dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum
|
||||
else:
|
||||
dgk = None
|
||||
|
||||
if gv is not None:
|
||||
_dgv = do.float() * o.float() - dv * v.float()
|
||||
if ctx.reverse:
|
||||
dgv = _dgv.cumsum(-2)
|
||||
else:
|
||||
_dgv_cumsum = _dgv.cumsum(-2)
|
||||
dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum
|
||||
else:
|
||||
dgv = None
|
||||
|
||||
return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None
|
||||
|
||||
|
||||
# if scale is None, use d_head_qk ** -0.5 by default. Otherwise specify the scale yourself. e.g. scale = 1.0
|
||||
def fused_recurrent_gla(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
gk: torch.Tensor = None,
|
||||
gv: torch.Tensor = None,
|
||||
scale: int = -1,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
causal: bool = True
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if scale == -1:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
if initial_state is not None:
|
||||
initial_state = initial_state.detach()
|
||||
if causal:
|
||||
o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state)
|
||||
return o, final_state
|
||||
else:
|
||||
# do not support initial_state yet. looks very strange for bidirectional modeling
|
||||
assert initial_state is None
|
||||
assert output_final_state is False
|
||||
o, final_state = FusedRecurrentGLAFunction.apply(
|
||||
q, k, v, gk, gv, scale, initial_state, output_final_state, False)
|
||||
o_reversed, final_state = FusedRecurrentGLAFunction.apply(
|
||||
q, k, v, gk, gv, scale, initial_state, output_final_state, True)
|
||||
return [o, o_reversed]
|
||||
Reference in New Issue
Block a user