This commit is contained in:
9
finetune/lora/v6/fla/ops/rwkv6/__init__.py
vendored
Normal file
9
finetune/lora/v6/fla/ops/rwkv6/__init__.py
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from .chunk import chunk_rwkv6
|
||||
from .recurrent_fuse import fused_recurrent_rwkv6
|
||||
|
||||
__all__ = [
|
||||
'chunk_rwkv6',
|
||||
'fused_recurrent_rwkv6'
|
||||
]
|
||||
921
finetune/lora/v6/fla/ops/rwkv6/chunk.py
vendored
Normal file
921
finetune/lora/v6/fla/ops/rwkv6/chunk.py
vendored
Normal file
@@ -0,0 +1,921 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fla.ops.utils import chunk_reversed_cumsum_fwd
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BS': 16}, num_warps=2),
|
||||
triton.Config({'BS': 16}, num_warps=4),
|
||||
triton.Config({'BS': 16}, num_warps=8),
|
||||
triton.Config({'BS': 32}, num_warps=2),
|
||||
triton.Config({'BS': 32}, num_warps=4),
|
||||
triton.Config({'BS': 32}, num_warps=8),
|
||||
triton.Config({'BS': 64}, num_warps=2),
|
||||
triton.Config({'BS': 64}, num_warps=4),
|
||||
triton.Config({'BS': 64}, num_warps=8),
|
||||
],
|
||||
key=['S']
|
||||
)
|
||||
@triton.jit
|
||||
def chunk_rwkv6_fwd_kernel_cum(
|
||||
s,
|
||||
o,
|
||||
o_minus_s,
|
||||
s_s_h,
|
||||
s_s_t,
|
||||
s_s_d,
|
||||
T: tl.constexpr,
|
||||
S: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BS: tl.constexpr
|
||||
):
|
||||
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
||||
|
||||
p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
# [BT, BS]
|
||||
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = tl.dot(m_s, b_s, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def post_process_grad(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
u,
|
||||
do,
|
||||
dk,
|
||||
dq,
|
||||
du,
|
||||
scale,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
H,
|
||||
T: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_h = i_bh % H
|
||||
|
||||
# Note that BK = tl.next_power_of_2(K), BV = tl.next_power_of_2(V)
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
|
||||
p_du = tl.make_block_ptr(du + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0))
|
||||
p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,))
|
||||
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
b_u = tl.load(p_u, boundary_check=(0,))
|
||||
|
||||
b_vdo = tl.sum(b_v * b_do, axis=1)
|
||||
b_du = b_vdo[:, None] * b_k * b_q * scale
|
||||
b_dq = b_vdo[:, None] * b_k * b_u[None, :] * scale
|
||||
b_dk = b_vdo[:, None] * b_q * b_u[None, :] * scale
|
||||
|
||||
b_dq += tl.load(p_dq, boundary_check=(0, 1))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
b_dk += tl.load(p_dk, boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
tl.store(p_du, b_du.to(p_du.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_fwd_kernel_h(
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
h,
|
||||
h0,
|
||||
ht,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
|
||||
for i_t in range(NT):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_g = tl.load(p_g, boundary_check=(0, 1))
|
||||
if i_t < NT - 1:
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
else:
|
||||
b_gn = tl.min(b_g, axis=1)
|
||||
b_h *= tl.exp(b_gn)[:, None]
|
||||
b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)
|
||||
b_h += tl.dot(b_k, b_v, allow_tf32=False)
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_fwd_kernel_intra(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
gs,
|
||||
u,
|
||||
A,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
scale,
|
||||
H,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
NC: tl.constexpr,
|
||||
DK: tl.constexpr
|
||||
):
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
|
||||
i_h = i_bh % H
|
||||
n_bh = tl.num_programs(2)
|
||||
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_q = i_t * BT + i_i * BC
|
||||
m_k = o_k < K
|
||||
|
||||
if i_i > i_j:
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
|
||||
p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
|
||||
# [BK,]
|
||||
b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0)
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
b_qg = (b_q * tl.exp(b_gs - b_gn[None, :]) * scale).to(b_q.dtype)
|
||||
# [BK, BC]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype)
|
||||
# [BC, BC]
|
||||
b_A = tl.dot(b_qg, b_kg, allow_tf32=False)
|
||||
tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
|
||||
elif i_i == i_j:
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
p_q_self = tl.make_block_ptr(q + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
o_i = tl.arange(0, BC)
|
||||
o_g = i_bh * T * K + (i_t * BT + i_j * BC) * K + o_k
|
||||
o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
|
||||
m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
|
||||
p_u = tl.make_block_ptr(u + i_h * DK, (DK,), (1,), (i_k * BK), (BK,), (0,))
|
||||
b_u = tl.load(p_u, boundary_check=(0,))
|
||||
for j in range(0, BC):
|
||||
# [BK,]
|
||||
b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
|
||||
b_gk = tl.load(g + o_g + j * K, mask=(m_k & ((i_t * BT + i_j * BC + j) < T)), other=0).to(tl.float32)
|
||||
# [BC,]
|
||||
b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_gs - b_gk[None, :]) * scale, 1)
|
||||
b_A = tl.where(o_i > j, b_A, 0.)
|
||||
# self
|
||||
b_q_self = tl.load(p_q_self, boundary_check=(0,)).to(tl.float32)
|
||||
A_self = tl.sum(b_q_self * b_k * b_u * scale, axis=0)
|
||||
m_self = tl.arange(0, BC) == j
|
||||
b_A = tl.where(m_self, A_self[None], b_A)
|
||||
tl.store(A + o_A + j, b_A.to(A.dtype.element_ty), mask=m_A)
|
||||
p_k = tl.advance(p_k, (K,))
|
||||
p_q_self = tl.advance(p_q_self, (K,))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_fwd_kernel_inter(
|
||||
q,
|
||||
v,
|
||||
gs,
|
||||
h,
|
||||
o,
|
||||
A,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BT, BK]
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
# [BT, BK]
|
||||
b_qg = (b_q * tl.exp(b_gs)).to(b_q.dtype)
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# works but dkw, owing to divine benevolence
|
||||
# [BT, BV]
|
||||
if i_k >= 0:
|
||||
b_o += tl.dot(b_qg, b_h, allow_tf32=False)
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
b_o += tl.dot(b_A, b_v, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_bwd_kernel_dh(
|
||||
q,
|
||||
g,
|
||||
gs,
|
||||
do,
|
||||
dh,
|
||||
dh0,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
NT: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr
|
||||
):
|
||||
i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
for i_t in range(NT - 1, -1, -1):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
# [BK, BT]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_q = (b_q * scale).to(b_q.dtype)
|
||||
# [BT, BV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
|
||||
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BK, BV]
|
||||
b_dh *= tl.exp(b_gn)[:, None]
|
||||
# [BK, BT]
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
b_q = (b_q * tl.exp(b_gs)).to(b_q.dtype)
|
||||
|
||||
# [BK, BV]
|
||||
b_dh += tl.dot(b_q, b_do, allow_tf32=False)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_dh0 = tl.make_block_ptr(dh0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_bwd_kernel_inter(
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
gs,
|
||||
A,
|
||||
do,
|
||||
dh,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
dA,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
s_v_h,
|
||||
s_v_t,
|
||||
s_v_d,
|
||||
s_h_h,
|
||||
s_h_t,
|
||||
s_h_d,
|
||||
scale,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr
|
||||
):
|
||||
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
n_bh = tl.num_programs(2)
|
||||
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gq = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
|
||||
|
||||
# [BT, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_gq = tl.load(p_gq, boundary_check=(0, 1))
|
||||
b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk)
|
||||
b_k = (b_k * b_gn).to(b_k.dtype)
|
||||
# [BT, BT]
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
|
||||
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
|
||||
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
|
||||
p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
|
||||
# [BT, BV]
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# [BV, BK]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
# [BT, BV]
|
||||
b_do = tl.load(p_do, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_dh = tl.load(p_dh, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BV]
|
||||
b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
|
||||
if i_k == 0:
|
||||
b_dv += tl.dot(b_A, b_do, allow_tf32=False)
|
||||
b_do = (b_do * scale).to(b_do.dtype)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
|
||||
# [BT, BT]
|
||||
b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dq += tl.dot(b_do, b_h, allow_tf32=False)
|
||||
# [BT, BK]
|
||||
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
|
||||
|
||||
b_dq = b_dq * tl.exp(b_gq)
|
||||
b_dk = b_dk * b_gn
|
||||
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_s = o_i[:, None] > o_i[None, :]
|
||||
# [BT, BT]
|
||||
b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
|
||||
if i_k == 0:
|
||||
tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def chunk_rwkv6_bwd_kernel_intra(
|
||||
q,
|
||||
k,
|
||||
g,
|
||||
gs,
|
||||
dA,
|
||||
dq,
|
||||
dk,
|
||||
s_k_h,
|
||||
s_k_t,
|
||||
s_k_d,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BC: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
NC: tl.constexpr
|
||||
):
|
||||
i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_t, i_i = i_c // NC, i_c % NC
|
||||
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_q = i_t * BT + i_i * BC
|
||||
m_k = o_k < K
|
||||
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
# [BK,]
|
||||
b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0)
|
||||
# [BC, BK]
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
b_dq = tl.zeros([BC, BK], dtype=tl.float32)
|
||||
for i_j in range(0, i_i):
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
|
||||
# [BC, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
|
||||
# [BC, BC]
|
||||
b_dA = tl.load(p_dA, boundary_check=(0, 1))
|
||||
# [BC, BK]
|
||||
b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
|
||||
b_dq *= tl.exp(b_gs - b_gn[None, :])
|
||||
|
||||
o_i = tl.arange(0, BC)
|
||||
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
|
||||
m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
|
||||
|
||||
for j in range(0, BC):
|
||||
p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
|
||||
|
||||
# [BC,]
|
||||
b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
|
||||
# [BK,]
|
||||
b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
|
||||
b_gkj = tl.load(g + i_bh * T * K + (o_q + j) * K + o_k, mask=(m_k & ((o_q + j) < T)), other=0)
|
||||
# [BC, BK]
|
||||
m_i = o_i[:, None] > j
|
||||
# [BC, BK]
|
||||
b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_gs - b_gkj[None, :]), 0.)
|
||||
|
||||
p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
|
||||
b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1))
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
tl.debug_barrier()
|
||||
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BK,]
|
||||
b_gn = tl.load(p_gn, boundary_check=(0,))
|
||||
# [BC, BK]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_gk = tl.load(p_gk, boundary_check=(0, 1))
|
||||
b_dk = tl.zeros([BC, BK], dtype=tl.float32)
|
||||
for i_j in range(i_i + 1, NC):
|
||||
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
|
||||
# [BC, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
b_gs = tl.load(p_gs, boundary_check=(0, 1))
|
||||
b_qg = (b_q * tl.exp(b_gs - b_gn[None, :])).to(b_q.dtype)
|
||||
# [BC, BC]
|
||||
b_dA = tl.load(p_dA, boundary_check=(0, 1))
|
||||
# [BC, BK]
|
||||
b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False)
|
||||
b_dk *= tl.exp(b_gn[None, :] - b_gk)
|
||||
|
||||
o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
|
||||
for j in range(0, BC):
|
||||
p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
|
||||
p_gqj = tl.make_block_ptr(gs + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
|
||||
# [BC,]
|
||||
b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
|
||||
# [BK,]
|
||||
b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
|
||||
b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32)
|
||||
# [BC, BK]
|
||||
m_i = o_i[:, None] < j
|
||||
b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.)
|
||||
|
||||
p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
|
||||
b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1))
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
class ChunkRWKV6Function(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def forward(ctx, r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level):
|
||||
q = r # alias
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT, BC = 64, 16
|
||||
BK = min(64, triton.next_power_of_2(K))
|
||||
BV = min(64, triton.next_power_of_2(V))
|
||||
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
|
||||
NK = triton.cdiv(K, BK)
|
||||
NV = triton.cdiv(V, BV)
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
num_stages = 1
|
||||
|
||||
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NV, NK, B * H)
|
||||
chunk_rwkv6_fwd_kernel_h[grid](
|
||||
k, v, g, h, h0, ht,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=h0 is not None,
|
||||
STORE_FINAL_STATE=ht is not None,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return h
|
||||
|
||||
final_state = None
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(B, H, K, V, dtype=torch.float)
|
||||
|
||||
g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)
|
||||
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
|
||||
# keep cummulative normalizer in fp32
|
||||
# this kernel is equivalent to
|
||||
# g_org = g_org.view(B, H, NT, BT, -1)
|
||||
# g = g_org.cumsum(-2).view(B, H, T, -1)
|
||||
# gs = g - g_org
|
||||
chunk_rwkv6_fwd_kernel_cum[grid](
|
||||
g_org, g, gs,
|
||||
g.stride(1), g.stride(2), g.stride(3),
|
||||
T=T, S=K, BT=BT
|
||||
)
|
||||
h = fwd_inner(
|
||||
q=q, k=k, v=v, g=g,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
h0=initial_state if initial_state is not None else None,
|
||||
ht=final_state if final_state is not None else None
|
||||
)
|
||||
A = q.new_zeros(NK, B, H, T, BT)
|
||||
grid = (NK, NT * NC * NC, B * H)
|
||||
chunk_rwkv6_fwd_kernel_intra[grid](
|
||||
q, k, g, gs, u, A,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
scale,
|
||||
H=H, T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, DK=K,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
A = A.sum(0, dtype=A.dtype)
|
||||
o = torch.empty_like(v)
|
||||
|
||||
grid = (NV, NT, B * H)
|
||||
chunk_rwkv6_fwd_kernel_inter[grid](
|
||||
q, v, gs, h, o, A,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
if checkpoint_level > 1:
|
||||
del h
|
||||
h, initial_state = None, None
|
||||
del g, gs
|
||||
ctx.save_for_backward(q, k, v, g_org, u, h, initial_state, A)
|
||||
ctx.BT = BT
|
||||
ctx.scale = scale
|
||||
ctx.checkpoint_level = checkpoint_level
|
||||
return o, final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
def backward(ctx, do, dht=None):
|
||||
q, k, v, g, u, h, initial_state, A = ctx.saved_tensors
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
BT, BC = ctx.BT, 16
|
||||
BK = min(64, triton.next_power_of_2(K))
|
||||
BV = min(64, triton.next_power_of_2(V))
|
||||
NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
|
||||
NK = triton.cdiv(K, BK)
|
||||
num_warps = 4 if BK == 64 else 2
|
||||
num_stages = 1
|
||||
|
||||
def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
h = q.new_empty(B, H, NT * K, V)
|
||||
grid = (NV, NK, B * H)
|
||||
chunk_rwkv6_fwd_kernel_h[grid](
|
||||
k, v, g, h, h0, ht,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=h0 is not None,
|
||||
STORE_FINAL_STATE=ht is not None,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return h
|
||||
|
||||
def bwd_inner(q, g, gs, h0, do, B, H, T, K, V, BT, BK, BV, NT, scale):
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
dh = q.new_empty(B, H, NT * K, V)
|
||||
dh0 = torch.empty_like(h0) if h0 is not None else None
|
||||
grid = (NK, NV, B * H)
|
||||
chunk_rwkv6_bwd_kernel_dh[grid](
|
||||
q, g, gs, do, dh, dh0,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
do.stride(1), do.stride(2), do.stride(3),
|
||||
dh.stride(1), dh.stride(2), dh.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
USE_INITIAL_STATE=h0 is not None,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
return dh, dh0
|
||||
|
||||
# recompute cumulative log decays.
|
||||
g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)
|
||||
def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H))
|
||||
# keep cummulative normalizer in fp32
|
||||
# this kernel is equivalent to
|
||||
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
|
||||
chunk_rwkv6_fwd_kernel_cum[grid](
|
||||
g_org, g, gs,
|
||||
g.stride(1), g.stride(2), g.stride(3),
|
||||
T=T, S=K, BT=BT
|
||||
)
|
||||
|
||||
# rerun the forward pass to get h if checkpoint_level >= 1
|
||||
if ctx.checkpoint_level == 1:
|
||||
h = fwd_inner(
|
||||
q=q, k=k, v=v, g=g,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
h0=initial_state if initial_state is not None else None,
|
||||
ht=None
|
||||
)
|
||||
|
||||
scale = ctx.scale
|
||||
dh, dh0 = bwd_inner(
|
||||
q, g, gs, initial_state, do,
|
||||
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
|
||||
scale=scale
|
||||
)
|
||||
dq = torch.empty_like(q, dtype=torch.float)
|
||||
dk = torch.empty_like(k, dtype=torch.float)
|
||||
dv = v.new_empty(NK, *v.shape)
|
||||
dA = q.new_zeros(B, H, T, BT)
|
||||
grid = (NK, NT, B * H)
|
||||
chunk_rwkv6_bwd_kernel_inter[grid](
|
||||
k, v, h, g, gs, A, do, dh, dq, dk, dv, dA,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3),
|
||||
h.stride(1), h.stride(2), h.stride(3),
|
||||
scale,
|
||||
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
dv = dv.sum(0, dtype=dv.dtype)
|
||||
grid = (NK, NT * NC, B * H)
|
||||
chunk_rwkv6_bwd_kernel_intra[grid](
|
||||
q, k, g, gs, dA, dq, dk,
|
||||
k.stride(1), k.stride(2), k.stride(3),
|
||||
T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
# TODO: fuse?
|
||||
dg = (dq * q)[:, :, 1:] - (dk * k)[:, :, 0:-1]
|
||||
dg = torch.nn.functional.pad(dg, (0, 0, 0, 1, 0, 0, 0, 0), value=0)
|
||||
dg = chunk_reversed_cumsum_fwd(dg).to(g)
|
||||
# equivalent to the following pytorch code.
|
||||
# du = ((do * v).sum(-1)[..., None] * k * q * scale).sum(-2).to(u)
|
||||
# dq += ((do * v).sum(-1)[..., None] * k * scale * u[:, :, None, :])
|
||||
# dk += ((do * v).sum(-1)[..., None] * q * scale * u[:, :, None, :])
|
||||
BT = 64
|
||||
grid = (triton.cdiv(T, BT), B * H)
|
||||
du = torch.empty_like(g, dtype=torch.float)
|
||||
post_process_grad[grid](
|
||||
q, k, v, u, do, dk, dq, du, scale,
|
||||
q.stride(1), q.stride(2), q.stride(3),
|
||||
v.stride(1), v.stride(2), v.stride(3), H=H,
|
||||
T=T, BT=BT, K=K, V=V, BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V),
|
||||
num_warps=4
|
||||
)
|
||||
du = du.sum([0, 2])
|
||||
return dq.to(q), dk.to(k), dv.to(v), dg.to(g), du.to(u), None, dh0, None, None
|
||||
|
||||
|
||||
def chunk_rwkv6(
|
||||
r: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
scale: Optional[int] = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
checkpoint_level: Optional[int] = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
r (torch.Tensor):
|
||||
reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.
|
||||
k (torch.Tensor):
|
||||
keys of shape `(B, H, T, K)`
|
||||
v (torch.Tensor):
|
||||
values of shape `(B, H, T, V)`
|
||||
w (torch.Tensor):
|
||||
data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.
|
||||
u (torch.Tensor):
|
||||
bonus of shape `(H, K)`
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RWKV6 attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `(B, H, K, V)`. Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
|
||||
checkpoint_level (Optional[int]):
|
||||
Checkpointing level; higher values will save more memories and do more recomputations during backward.
|
||||
Default: `0`:
|
||||
- Level `0`: store forward hidden states for backprop.
|
||||
- Level `1`: recompute the forward hidden states during backward.
|
||||
"""
|
||||
assert checkpoint_level in [0, 1]
|
||||
if scale is None:
|
||||
scale = r.shape[-1] ** -0.5
|
||||
o, final_state = ChunkRWKV6Function.apply(r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level)
|
||||
return o, final_state
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6
|
||||
B = 4
|
||||
H = 4
|
||||
L = 1024
|
||||
K = 100
|
||||
V = 120
|
||||
|
||||
torch.manual_seed(0)
|
||||
dtype = torch.float32
|
||||
q = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True)
|
||||
k = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True)
|
||||
v = torch.randn(B, H, L, V).cuda().to(dtype).requires_grad_(True)
|
||||
w = (-torch.randn(B, H, L, K).exp()).cuda().to(torch.float32).requires_grad_(True)
|
||||
u = torch.randn(H, K).cuda().to(dtype).requires_grad_(True)
|
||||
h0 = torch.randn(B, H, K, V).cuda().to(dtype).requires_grad_(True)
|
||||
do = torch.rand_like(v).cuda()
|
||||
o, ht = fused_recurrent_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True)
|
||||
o.backward(do)
|
||||
dq, q.grad = q.grad.clone(), None
|
||||
dk, k.grad = k.grad.clone(), None
|
||||
dv, v.grad = v.grad.clone(), None
|
||||
dw, w.grad = w.grad.clone(), None
|
||||
du, u.grad = u.grad.clone(), None
|
||||
dh0, h0.grad = h0.grad.clone(), None
|
||||
o2, ht2 = chunk_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True)
|
||||
o2.backward(do)
|
||||
torch.testing.assert_close(o, o2, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(ht, ht2, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(q.grad, dq, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(k.grad, dk, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(v.grad, dv, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(w.grad, dw, rtol=0, atol=1e-4)
|
||||
torch.testing.assert_close(u.grad, du, rtol=0, atol=2e-4)
|
||||
torch.testing.assert_close(h0.grad, dh0, rtol=0, atol=2e-4)
|
||||
|
||||
print("All tests passed!")
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
# argument names to use as an x-axis for the plot
|
||||
x_names=['T'],
|
||||
# different possible values for `x_name`
|
||||
x_vals=[128 * 2 ** i for i in range(0, 8)],
|
||||
# argument name whose value corresponds to a different line in the plot
|
||||
line_arg='provider',
|
||||
# possible values for `line_arg``
|
||||
line_vals=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'],
|
||||
# label name for the lines
|
||||
line_names=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'],
|
||||
# line styles
|
||||
styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')],
|
||||
ylabel="Execution Time (ms)", # label name for the y-axis
|
||||
# name for the plot. Used also as a file name for saving the plot.
|
||||
plot_name="Performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(T, provider):
|
||||
device = 'cuda'
|
||||
dtype = torch.bfloat16
|
||||
requires_grad = True
|
||||
B, H, K = 16, 4, 128
|
||||
|
||||
q = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
|
||||
k = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
|
||||
v = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype)
|
||||
w = F.logsigmoid(torch.randn(B, H, T, K)).to(dtype=dtype, device=device).requires_grad_(True)
|
||||
u = torch.randn(H, K, device=device, requires_grad=requires_grad, dtype=dtype)
|
||||
|
||||
do = torch.ones_like(q, dtype=dtype)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
results = 0, 0, 0
|
||||
if provider == 'recurrent':
|
||||
results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u), quantiles=quantiles)
|
||||
if provider == 'chunk':
|
||||
results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u), quantiles=quantiles)
|
||||
if provider == 'recurrent_bwd':
|
||||
results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u)
|
||||
[0].backward(do), quantiles=quantiles)
|
||||
if provider == 'chunk_bwd':
|
||||
results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u)[0].backward(do), quantiles=quantiles)
|
||||
return results
|
||||
benchmark.run(print_data=True)
|
||||
79
finetune/lora/v6/fla/ops/rwkv6/chunk_naive.py
vendored
Normal file
79
finetune/lora/v6/fla/ops/rwkv6/chunk_naive.py
vendored
Normal file
@@ -0,0 +1,79 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from fla.ops.rwkv6.chunk import chunk_rwkv6
|
||||
from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6
|
||||
|
||||
|
||||
def naive_chunk_rwkv6(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
u,
|
||||
chunk_size=32,
|
||||
initial_state=None,
|
||||
output_final_state=True,
|
||||
):
|
||||
assert q.shape[-2] % chunk_size == 0
|
||||
orig_dtype = q.dtype
|
||||
num_chunk = q.shape[-2] // chunk_size
|
||||
u = u.unsqueeze(0)
|
||||
|
||||
q, k, v, w = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), (q, k, v, w))
|
||||
|
||||
w_cumsum = w.cumsum(-2)
|
||||
|
||||
kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp()
|
||||
wkv = kw.transpose(-1, -2) @ v
|
||||
|
||||
wkv_new = torch.zeros_like(wkv)
|
||||
|
||||
for i in range(num_chunk - 1):
|
||||
wkv_new[:, :, i+1] = (wkv_new[:, :, i] * w_cumsum[:, :, i, -1, :, None].exp()) + wkv[:, :, i]
|
||||
|
||||
o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, (q * (w_cumsum - w).exp()))
|
||||
|
||||
o_intra = torch.zeros_like(o_inter)
|
||||
for i in range(chunk_size):
|
||||
attn = (q[:, :, :, i, None] * k * (w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum).exp()).sum(-1)
|
||||
mask = (torch.arange(0, chunk_size) < i).to(attn.device)
|
||||
attn.masked_fill_(~mask, 0)
|
||||
intra_inter_o = (attn.unsqueeze(-1) * v).sum(-2)
|
||||
intra_intra_o = (q[:, :, :, i] * u.unsqueeze(2) * k[:, :, :, i]).sum(-1).unsqueeze(-1) * v[:, :, :, i]
|
||||
o_intra[:, :, :, i] = intra_inter_o + intra_intra_o
|
||||
o = o_inter + o_intra
|
||||
return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
B = 4
|
||||
H = 4
|
||||
L = 1024
|
||||
D = 100
|
||||
dtype = torch.bfloat16
|
||||
require_grad = True
|
||||
q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad)
|
||||
k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad)
|
||||
v = torch.randn(B, H, L, 2*D).cuda().to(dtype).requires_grad_(require_grad)
|
||||
w = torch.nn.functional.logsigmoid(torch.randn(B, H, L, D)).cuda().to(dtype).requires_grad_(require_grad)
|
||||
u = (torch.randn(H, D).cuda().to(dtype)).requires_grad_(require_grad)
|
||||
do = torch.rand_like(v).cuda()
|
||||
o2, _ = chunk_rwkv6(q, k, v, w.clone(), u)
|
||||
o, _ = fused_recurrent_rwkv6(q, k, v, w, u, scale=1.0)
|
||||
o.backward(do)
|
||||
dq, q.grad = q.grad.clone(), None
|
||||
dk, k.grad = k.grad.clone(), None
|
||||
dv, v.grad = v.grad.clone(), None
|
||||
dw, w.grad = w.grad.clone(), None
|
||||
du, u.grad = u.grad.clone(), None
|
||||
print((o - o2).abs().max())
|
||||
o2.backward(do)
|
||||
print((o-o2).abs().max())
|
||||
print((q.grad - dq).abs().max())
|
||||
print((k.grad - dk).abs().max())
|
||||
print((v.grad - dv).abs().max())
|
||||
print((w.grad - dw).abs().max())
|
||||
print((u.grad - du).abs().max())
|
||||
378
finetune/lora/v6/fla/ops/rwkv6/recurrent_fuse.py
vendored
Normal file
378
finetune/lora/v6/fla/ops/rwkv6/recurrent_fuse.py
vendored
Normal file
@@ -0,0 +1,378 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (c) 2024, Songlin Yang
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from fla.ops.utils import chunk_reversed_cumsum_fwd
|
||||
from fla.utils import contiguous
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_rwkv6_fwd_kernel(
|
||||
q, # query [B, H, T, K]
|
||||
k, # key [B, H, T, K]
|
||||
v, # value [B, H, T, V]
|
||||
w, # log gate [B, H, T, K]
|
||||
u, # bonus [B, H, K]
|
||||
o, # output [B, H, T, V]
|
||||
# initial hidden state initialization [B, H, K, V]
|
||||
h0,
|
||||
ht, # final hidden state [B, H, K, V]
|
||||
s_k_h, # stride size: T * K
|
||||
s_v_h, # stride size: T * V
|
||||
scale, # K ** -0.5
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
|
||||
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
|
||||
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK
|
||||
|
||||
mask_bk = (i_k * BK + tl.arange(0, BK)) < K
|
||||
mask_bv = (i_v * BV + tl.arange(0, BV)) < V
|
||||
mask_kv = mask_bv[:, None] & mask_bk[None, :]
|
||||
|
||||
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
|
||||
for _ in range(0, T):
|
||||
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_w = tl.exp(b_w)
|
||||
b_kv = b_k[None, :] * b_v[:, None]
|
||||
b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :]
|
||||
b_o = tl.sum(b_o, axis=1)
|
||||
b_h = b_h * b_w[None, :]
|
||||
b_h += b_kv
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
|
||||
p_q += -K if REVERSE else K
|
||||
p_k += -K if REVERSE else K
|
||||
p_o += -V if REVERSE else V
|
||||
p_v += -V if REVERSE else V
|
||||
p_w += -K if REVERSE else K
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)
|
||||
|
||||
|
||||
# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
|
||||
@triton.jit
|
||||
def fused_recurrent_rwkv6_bwd_kernel_dq(
|
||||
# B: B, H: H, T: T, D: d_head
|
||||
# NV: number of split in the V dimension. NK: number of split in the K dimension
|
||||
k, # key [B, H, T, V]
|
||||
v, # value [B, H, T, V]
|
||||
w, # log gate [B, H, T, K]
|
||||
u, # bonus [B, H, K]
|
||||
|
||||
do, # gradient of output [B, H, T, V]
|
||||
dq, # gradient of query [NV, B, H, T, K]
|
||||
dq_aux, # gradient of query_aux [NV, B, H, T, K]
|
||||
|
||||
# initial hidden state initialization [B, H, K, V]
|
||||
h0,
|
||||
|
||||
s_k_h, # stride size: T * K
|
||||
s_v_h, # stride size: T * V
|
||||
|
||||
scale, # K ** -0.5
|
||||
B: tl.constexpr, # B
|
||||
H: tl.constexpr, # H
|
||||
T: tl.constexpr, # T
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
K: tl.constexpr, # K
|
||||
V: tl.constexpr, # V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
|
||||
p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_dq_aux = dq_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
|
||||
p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK
|
||||
|
||||
mask_bk = i_k * BK + tl.arange(0, BK) < K
|
||||
mask_bv = i_v * BV + tl.arange(0, BV) < V
|
||||
mask_kv = mask_bv[:, None] & mask_bk[None, :]
|
||||
b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
|
||||
b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
|
||||
|
||||
for _ in range(0, T):
|
||||
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_kv = b_k[None, :] * b_v[:, None]
|
||||
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_w = tl.exp(b_w)
|
||||
h_q = b_h * b_do[:, None]
|
||||
b_dq = tl.sum(h_q + b_kv * b_u[None, :] * b_do[:, None], axis=0)
|
||||
b_dq *= scale
|
||||
b_dq_aux = tl.sum(h_q, axis=0)
|
||||
b_h = b_h * b_w[None, :]
|
||||
b_h += b_kv
|
||||
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk)
|
||||
tl.store(p_dq_aux, b_dq_aux.to(p_dq_aux.dtype.element_ty), mask=mask_bk)
|
||||
p_k += -K if REVERSE else K
|
||||
p_do += -V if REVERSE else V
|
||||
p_v += -V if REVERSE else V
|
||||
p_w += -K if REVERSE else K
|
||||
p_dq += -K if REVERSE else K
|
||||
p_dq_aux += -K if REVERSE else K
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_rwkv6_bwd_kernel_dkv(
|
||||
# B: B, H: H, T: T, D: d_head
|
||||
# NV: number of split in the V dimension. NK: number of split in the K dimension
|
||||
q, # query [B, H, T, K]
|
||||
k, # key [B, H, T, V]
|
||||
v, # value [B, H, T, V]
|
||||
w, # log gate [B, H, T, K]
|
||||
u, # bonus [B, H, K]
|
||||
|
||||
do, # gradient of output [B, H, T, V]
|
||||
dk,
|
||||
dk_aux,
|
||||
dv,
|
||||
dh0,
|
||||
|
||||
# initial hidden state initialization [B, H, K, V]
|
||||
s_k_h, # stride size: T * K
|
||||
s_v_h, # stride size: T * V
|
||||
|
||||
scale, # K ** -0.5
|
||||
B, # B
|
||||
H, # H
|
||||
T, # T
|
||||
BK: tl.constexpr, # BLOCK SIZE along the K dimension
|
||||
BV: tl.constexpr, # BLOCK SIZE along the V dimension
|
||||
K: tl.constexpr, # K
|
||||
V: tl.constexpr, # V
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction
|
||||
):
|
||||
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_h = i_bh % H
|
||||
p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
|
||||
p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
|
||||
p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
p_dk_aux = dk_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0)
|
||||
p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0)
|
||||
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
mask_bk = i_k * BK + tl.arange(0, BK) < K
|
||||
mask_bv = i_v * BV + tl.arange(0, BV) < V
|
||||
mask_kv = mask_bk[:, None] & mask_bv[None, :]
|
||||
|
||||
p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK
|
||||
b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)
|
||||
|
||||
for _ in range(T-1, -1, -1):
|
||||
b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
|
||||
b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)
|
||||
b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
|
||||
b_dkv = b_q[:, None] * b_do[None, :]
|
||||
b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
|
||||
tl.store(p_dk_aux, b_dk.to(p_dk_aux.dtype.element_ty), mask=mask_bk)
|
||||
b_dk += tl.sum(b_dkv * b_u[:, None] * b_v[None, :], axis=1)
|
||||
b_dv = tl.sum((b_dh + (b_dkv * b_u[:, None])) * b_k[:, None], axis=0)
|
||||
|
||||
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
|
||||
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
|
||||
b_dh *= tl.exp(b_w)[:, None]
|
||||
b_dh += b_dkv
|
||||
|
||||
p_q += K if REVERSE else -K
|
||||
p_k += K if REVERSE else -K
|
||||
p_v += V if REVERSE else -V
|
||||
p_w += K if REVERSE else -K
|
||||
p_do += V if REVERSE else -V
|
||||
p_dk += K if REVERSE else -K
|
||||
p_dk_aux += K if REVERSE else -K
|
||||
p_dv += V if REVERSE else -V
|
||||
|
||||
if USE_INITIAL_STATE:
|
||||
p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
|
||||
tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_kv)
|
||||
|
||||
|
||||
class FusedRecurrentRWKV6Function(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_fwd
|
||||
def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False):
|
||||
# alias
|
||||
q = r
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
|
||||
BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(B, H, K, V)
|
||||
else:
|
||||
final_state = None
|
||||
|
||||
o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
|
||||
grid = (NV, NK, B * H)
|
||||
fused_recurrent_rwkv6_fwd_kernel[grid](
|
||||
q, k, v, w, u, o, initial_state, final_state,
|
||||
k.stride(1),
|
||||
v.stride(1),
|
||||
scale,
|
||||
B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
STORE_FINAL_STATE=final_state is not None,
|
||||
REVERSE=reverse,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages
|
||||
)
|
||||
|
||||
o = o.sum(0)
|
||||
ctx.save_for_backward(q, k, v, w, u, initial_state, o)
|
||||
ctx.scale = scale
|
||||
ctx.reverse = reverse
|
||||
# we do not need the gradient of the final state from the next chunk
|
||||
# similiar to Trunctated BPTT
|
||||
if final_state is not None:
|
||||
final_state = final_state.detach()
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@contiguous
|
||||
@custom_bwd
|
||||
def backward(ctx, do, d_final_state=None):
|
||||
q, k, v, w, u, initial_state, o = ctx.saved_tensors
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
scale = ctx.scale
|
||||
|
||||
BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
num_stages = 1
|
||||
num_warps = 1
|
||||
dq = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
|
||||
dq_aux = torch.empty_like(dq)
|
||||
grid = (NV, NK, B * H)
|
||||
|
||||
fused_recurrent_rwkv6_bwd_kernel_dq[grid](
|
||||
k, v, w, u, do, dq, dq_aux, initial_state,
|
||||
q.stride(1),
|
||||
v.stride(1),
|
||||
scale,
|
||||
B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
REVERSE=ctx.reverse,
|
||||
)
|
||||
dq = dq.sum(0).to(q)
|
||||
dq_aux = dq_aux.sum(0)
|
||||
|
||||
BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
|
||||
dk = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
|
||||
dk_aux = q.new_empty(NV, B, H, T, K, dtype=torch.float32)
|
||||
dv = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
|
||||
dh0 = initial_state.new_empty(B, H, K, V) if initial_state is not None else None
|
||||
grid = (NV, NK, B * H)
|
||||
fused_recurrent_rwkv6_bwd_kernel_dkv[grid](
|
||||
q, k, v, w, u, do, dk, dk_aux, dv, dh0,
|
||||
q.stride(1),
|
||||
v.stride(1),
|
||||
scale,
|
||||
B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
USE_INITIAL_STATE=initial_state is not None,
|
||||
REVERSE=ctx.reverse,
|
||||
)
|
||||
dk = dk.sum(0).to(k)
|
||||
dv = dv.sum(0).to(v)
|
||||
dk_aux = dk_aux.sum(0)
|
||||
|
||||
dw = (dq_aux * q * scale)[:, :, 1:] - (dk_aux * k)[:, :, 0:-1]
|
||||
dw = torch.nn.functional.pad(dw, (0, 0, 0, 1, 0, 0, 0, 0), value=0)
|
||||
dw = chunk_reversed_cumsum_fwd(dw).to(w)
|
||||
|
||||
du = ((do * v).sum(-1)[..., None] * k * q * scale).sum([0, -2]).to(u)
|
||||
return dq, dk, dv, dw, du, None, dh0, None, None
|
||||
|
||||
|
||||
def fused_recurrent_rwkv6(
|
||||
r: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
scale: int = -1,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
causal: bool = True
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
r (torch.Tensor):
|
||||
reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.
|
||||
k (torch.Tensor):
|
||||
keys of shape `(B, H, T, K)`
|
||||
v (torch.Tensor):
|
||||
values of shape `(B, H, T, V)`
|
||||
w (torch.Tensor):
|
||||
data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.
|
||||
u (torch.Tensor):
|
||||
bonus of shape `(H, K)`
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RWKV6 attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `(B, H, K, V)`. Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
|
||||
"""
|
||||
if scale == -1:
|
||||
scale = r.shape[-1] ** -0.5
|
||||
o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state)
|
||||
return o, final_state
|
||||
102
finetune/lora/v6/fla/ops/rwkv6/recurrent_naive.py
vendored
Normal file
102
finetune/lora/v6/fla/ops/rwkv6/recurrent_naive.py
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def naive_recurrent_rwkv6(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
scale: Optional[float] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: Optional[bool] = False
|
||||
):
|
||||
orig_dtype = q.dtype
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
|
||||
h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
|
||||
o = torch.zeros_like(v)
|
||||
|
||||
if scale is None:
|
||||
scale = K ** -0.5
|
||||
|
||||
if initial_state is not None:
|
||||
h += initial_state
|
||||
|
||||
for i in range(T):
|
||||
q_i = q[:, :, i, :] * scale
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i, :]
|
||||
w_i = w[:, :, i].exp()
|
||||
kv_i = k_i[..., None] * v_i[..., None, :]
|
||||
o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None]
|
||||
o[:, :, i] = o_i.sum(-2)
|
||||
h = h * w_i[..., None] + kv_i
|
||||
ht = h if output_final_state else None
|
||||
return o.to(orig_dtype), ht
|
||||
|
||||
|
||||
def naive_recurrent_rwkv6_bwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
u,
|
||||
o,
|
||||
do,
|
||||
initial_state=None,
|
||||
output_final_state=False
|
||||
):
|
||||
q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
|
||||
dq = torch.zeros_like(q)
|
||||
dq_aux = torch.zeros_like(q)
|
||||
|
||||
if initial_state is not None:
|
||||
h += initial_state
|
||||
|
||||
for i in range(T):
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i]
|
||||
w_i = w[:, :, i].exp()
|
||||
kv_i = k_i[..., None] * v_i[..., None, :]
|
||||
h_i = (h + u[None, ..., None] * kv_i)
|
||||
dq_i = (do[:, :, i, None, :] * h_i).sum(-1)
|
||||
dq_aux_i = (do[:, :, i, None, :] * h).sum(-1)
|
||||
dq[:, :, i] = dq_i
|
||||
dq_aux[:, :, i] = dq_aux_i
|
||||
h = h * w_i[..., None] + kv_i
|
||||
|
||||
du = torch.zeros_like(u)
|
||||
dh = torch.zeros_like(h)
|
||||
dk = torch.zeros_like(k)
|
||||
dk_aux = torch.zeros_like(k)
|
||||
dv = torch.zeros_like(v)
|
||||
|
||||
for i in range(T - 1, -1, -1):
|
||||
d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i]
|
||||
du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
|
||||
du += du_i
|
||||
dk_i = (dh * v_i[..., None, :]).sum(-1)
|
||||
dk_aux[:, :, i] = dk_i
|
||||
dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
|
||||
dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
|
||||
dv_i += (dh * k_i[..., None]).sum(-2)
|
||||
|
||||
dk[:, :, i] = dk_i
|
||||
dv[:, :, i] = dv_i
|
||||
dh = dh * w[:, :, i, :, None].exp() + d_kv_i
|
||||
|
||||
# dw = q * dq_aux - k * dk_aux
|
||||
dw = torch.zeros_like(w)
|
||||
for i in range(T - 2, -1, -1):
|
||||
dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]
|
||||
|
||||
return dq, dk, dv, dw, du
|
||||
Reference in New Issue
Block a user