# -*- coding: utf-8 -*-
# Adapted from https://github.com/codekansas/rwkv
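"""Triton kernels for the RWKV v4 WKV operator.

Implements a fused recurrent forward pass and a matching backward pass,
wrapped in a ``torch.autograd.Function``. Expected shapes, as inferred from
the wrappers below: ``w`` and ``u`` are ``(chans,)``, ``k`` and ``v`` are
``(bsz, tsz, chans)``, and ``state`` is ``(bsz, 3, 1, chans)``, holding the
``alpha``, ``beta``, and ``eps`` accumulators stacked along dim 1.
"""
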
from typing import Any, cast

import torch
import triton
import triton.language as tl
from torch import Tensor
from torch.autograd.function import Function, FunctionCtx, once_differentiable


def get_block_size_c(chans: int) -> int:
    """Returns the per-program channel block size: 32, 64, or 128 depending on ``chans``."""
    if chans < 32:
        return 32
    if chans < 64:
        return 64
    return 128

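# The WKV computation is a per-channel recurrence. In unnormalized form,
# each timestep produces
#
#     wkv_t   = (alpha_{t-1} + exp(u + k_t) * v_t) / (beta_{t-1} + exp(u + k_t))
#     alpha_t = exp(w) * alpha_{t-1} + exp(k_t) * v_t
#     beta_t  = exp(w) * beta_{t-1} + exp(k_t)
#
# with w <= 0, so exp(w) acts as a per-channel decay. Evaluated literally, the
# exponentials overflow, so the kernels below carry a third state channel
# ``eps``: a running maximum exponent. ``alpha`` and ``beta`` are stored
# divided by exp(eps), and every update rescales against a fresh maximum --
# an incremental form of the usual log-sum-exp trick.
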
@triton.jit
def fused_recurrent_rwkv4_forward_kernel(
    # W
    w_ptr,
    w_s_c,
    # U
    u_ptr,
    u_s_c,
    # K
    k_ptr,
    k_s_b,
    k_s_t,
    k_s_c,
    # V
    v_ptr,
    v_s_b,
    v_s_t,
    v_s_c,
    # State
    state_ptr,
    state_s_b,
    state_s_abe,
    state_s_c,
    # WKV
    wkv_ptr,
    wkv_s_b,
    wkv_s_t,
    wkv_s_c,
    # Output state
    state_out_ptr,
    state_out_s_b,
    state_out_s_abe,
    state_out_s_t,
    state_out_s_c,
    # Params
    chans,
    tsz,
    BLOCK_SIZE_C: tl.constexpr,
):
    # Each program handles one batch element and one block of channels.
    b_idx = tl.program_id(0)
    c_idx = tl.program_id(1)

    cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
    cmask = cs < chans

    # Pointers to the batch (and possibly channel) for the input tensors.
    k_ptr = k_ptr + b_idx * k_s_b
    v_ptr = v_ptr + b_idx * v_s_b
    alpha_ptr = state_ptr + b_idx * state_s_b
    beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
    eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe

    # Pointers to the batch (and possibly channel) for the output tensors.
    wkv_ptr = wkv_ptr + b_idx * wkv_s_b
    alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b
    beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe
    eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe

    # Loads the initial state and the per-channel parameters.
    alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
    beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
    eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
    w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)
    u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)

    for t in range(tsz):
        kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)
        vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)

        # Computes the WKV output for this timestep, rescaling both exponents
        # by their maximum tau so that neither exp() can overflow.
        ukt = u + kt
        tau = tl.maximum(ukt, eps)
        e1a = tl.exp(eps - tau)
        e2a = tl.exp(ukt - tau)
        wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)
        tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)

        # Updates the running numerator (alpha), denominator (beta), and
        # shared exponent (eps), again rescaling against a fresh maximum.
        w_eps = w + eps
        eps = tl.maximum(w_eps, kt)
        e1b = tl.exp(w_eps - eps)
        e2b = tl.exp(kt - eps)
        alpha = e1b * alpha + e2b * vt
        beta = e1b * beta + e2b
        tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)
        tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)
        tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)

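# Launch-grid sketch: with, say, chans = 768 the block size is 128, so the
# wrapper below launches the kernel on a (bsz, cdiv(768, 128)) = (bsz, 6)
# grid, one program per (batch element, channel block). The numbers here are
# purely illustrative.
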
def fused_recurrent_rwkv4_forward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
) -> tuple[Tensor, Tensor]:
    (bsz, tsz, chans) = k.shape

    # New tensors to output.
    wkvs = k.new_empty(bsz, tsz, chans)
    state_out = k.new_empty(bsz, 3, tsz, chans)

    # Constants.
    block_size_c = get_block_size_c(chans)

    def grid(meta: dict[str, Any]) -> tuple[int, ...]:
        return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))

    fused_recurrent_rwkv4_forward_kernel[grid](
        # W
        w,
        w.stride(0),
        # U
        u,
        u.stride(0),
        # K
        k,
        k.stride(0),
        k.stride(1),
        k.stride(2),
        # V
        v,
        v.stride(0),
        v.stride(1),
        v.stride(2),
        # State
        state,
        state.stride(0),
        state.stride(1),
        state.stride(3),
        # WKV
        wkvs,
        wkvs.stride(0),
        wkvs.stride(1),
        wkvs.stride(2),
        # Output state
        state_out,
        state_out.stride(0),
        state_out.stride(1),
        state_out.stride(2),
        state_out.stride(3),
        # Params
        chans,
        tsz,
        BLOCK_SIZE_C=block_size_c,
    )

    # Prepends the initial state, so state_out[:, :, t] is the state *before*
    # timestep t and state_out[:, :, -1] is the final state.
    state_out = torch.cat((state, state_out), dim=2)

    return wkvs, state_out

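# For reading and testing: a minimal eager-mode sketch of the same forward
# recurrence. This helper is hypothetical (not part of the kernel API) and,
# like fused_recurrent_rwkv4_forward, expects the already-transformed decay
# -exp(w) rather than the raw w parameter.
def _reference_rwkv4_forward(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> Tensor:
    alpha, beta, eps = state[:, 0, -1], state[:, 1, -1], state[:, 2, -1]
    wkvs = []
    for t in range(k.shape[1]):
        kt, vt = k[:, t], v[:, t]
        ukt = u + kt
        tau = torch.maximum(ukt, eps)
        e1, e2 = torch.exp(eps - tau), torch.exp(ukt - tau)
        wkvs.append((e1 * alpha + e2 * vt) / (e1 * beta + e2))
        w_eps = w + eps
        eps = torch.maximum(w_eps, kt)
        e1, e2 = torch.exp(w_eps - eps), torch.exp(kt - eps)
        alpha = e1 * alpha + e2 * vt
        beta = e1 * beta + e2
    return torch.stack(wkvs, dim=1)
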
@triton.jit
def fused_recurrent_rwkv4_backward_kernel(
    # W
    w_ptr,
    w_s_c,
    # U
    u_ptr,
    u_s_c,
    # K
    k_ptr,
    k_s_b,
    k_s_t,
    k_s_c,
    # V
    v_ptr,
    v_s_b,
    v_s_t,
    v_s_c,
    # State
    state_ptr,
    state_s_b,
    state_s_abe,
    state_s_t,
    state_s_c,
    # WKV grad
    gwkv_ptr,
    gwkv_s_b,
    gwkv_s_t,
    gwkv_s_c,
    # Output state grad
    gstate_out_ptr,
    gstate_out_s_b,
    gstate_out_s_abe,
    gstate_out_s_c,
    # W grad
    gw_ptr,
    gw_s_c,
    # U grad
    gu_ptr,
    gu_s_c,
    # K grad
    gk_ptr,
    gk_s_b,
    gk_s_t,
    gk_s_c,
    # V grad
    gv_ptr,
    gv_s_b,
    gv_s_t,
    gv_s_c,
    # State grad
    gstate_ptr,
    gstate_s_b,
    gstate_s_abe,
    gstate_s_c,
    # Params
    tsz,
    chans,
    BLOCK_SIZE_C: tl.constexpr,
):
    # Each program handles one batch element and one block of channels.
    b_idx = tl.program_id(0)
    c_idx = tl.program_id(1)

    cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
    cmask = cs < chans

    # Pointers to the batch (and possibly channel) for the input tensors.
    k_ptr = k_ptr + b_idx * k_s_b
    v_ptr = v_ptr + b_idx * v_s_b
    alpha_ptr = state_ptr + b_idx * state_s_b
    beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
    eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe

    # Pointers to the batch (and possibly channel) for the output tensors.
    gk_ptr = gk_ptr + b_idx * gk_s_b
    gv_ptr = gv_ptr + b_idx * gv_s_b

    # Pointers to the gradients which were received by the function.
    gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b
    galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b
    gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe
    geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe

    # Loads the incoming state gradients and the per-channel parameters.
    galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
    gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
    geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
    w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)
    u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)

    # Gradient accumulators for the shared parameters.
    gw = tl.zeros_like(w)
    gu = tl.zeros_like(u)

    # Loads the state after the final timestep. The saved state holds tsz
    # steps, but it is a slice of the forward pass's (tsz + 1)-step buffer,
    # so the read at time index tsz lands on the final state in that buffer.
    alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
    beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
    eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)

    # Walks the sequence in reverse, re-deriving each step's intermediates
    # from the stored states and pushing gradients back through the
    # recurrence.
    for t in range(tsz):
        tc = tsz - t - 1

        kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)
        vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)

        alpha_curr = alpha_prev
        beta_curr = beta_prev
        eps_curr = eps_prev

        alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
        beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
        eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)

        ukt = u + kt
        tau = tl.maximum(ukt, eps_prev)
        e1 = tl.exp(eps_prev - tau)
        e2 = tl.exp(ukt - tau)

        euke = tl.exp(ukt + eps_prev - 2 * tau)

        denom = e1 * beta_prev + e2
        denom_sq = denom * denom

        gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)

        # Backpropagates wkv gradients.
        guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq
        gu += guk
        gk = guk
        gv = gwkvt * e2 / denom

        galpha_wkv = gwkvt * e1 / denom
        gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq
        geps_wkv_denom = e1 * beta_prev + e2
        geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)

        e1 = tl.exp(w + eps_prev - eps_curr)
        e2 = tl.exp(kt - eps_curr)

        # Backpropagates alpha gradients.
        galpha_we = galpha * e1 * alpha_prev
        gw += galpha_we
        gk += galpha * e2 * vt
        gv += galpha * e2
        geps += galpha * -alpha_curr

        # Backpropagates beta gradients.
        gbeta_we = gbeta * e1 * beta_prev
        gw += gbeta_we
        gk += gbeta * e2
        geps += gbeta * -beta_curr

        # Backpropagates epsilon gradients: eps = max(w + eps_prev, k_t), so
        # the gradient routes to whichever branch won the maximum.
        geps_mask = w + eps_prev > kt
        geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))
        gw += geps_we
        gk += tl.where(geps_mask, tl.zeros_like(geps), geps)

        # Stores the gradients for k and v.
        tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)
        tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)

        # Computes new gradients for alpha, beta and epsilon.
        galpha = galpha * e1 + galpha_wkv
        gbeta = gbeta * e1 + gbeta_wkv
        geps = galpha_we + gbeta_we + geps_we + geps_wkv

    # Stores final gradients for alpha, beta and epsilon.
    galpha_ptr = gstate_ptr + b_idx * gstate_s_b
    gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe
    geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe
    tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)
    tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)
    tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)

    # Accumulates the shared w and u gradients atomically: programs with
    # different batch indices all add into the same per-channel slots, so a
    # plain load/add/store here would race for bsz > 1. (Floating-point
    # atomics keep the sum correct but not bitwise deterministic.)
    tl.atomic_add(gw_ptr + gw_s_c * cs, gw, mask=cmask)
    tl.atomic_add(gu_ptr + gu_s_c * cs, gu, mask=cmask)

def fused_recurrent_rwkv4_backward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
    grad_wkv: Tensor,
    grad_state: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    bsz, tsz, chans = k.shape

    # New tensors to output. gw and gu must start at zero because the kernel
    # accumulates into them across the batch dimension.
    gw = torch.zeros_like(w)
    gu = torch.zeros_like(u)
    gk = torch.empty_like(k)
    gv = torch.empty_like(v)
    gstate = k.new_empty(bsz, 3, 1, chans)

    # Constants.
    block_size_c = get_block_size_c(chans)

    def grid(meta: dict[str, Any]) -> tuple[int, ...]:
        return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))

    fused_recurrent_rwkv4_backward_kernel[grid](
        # W
        w,
        w.stride(0),
        # U
        u,
        u.stride(0),
        # K
        k,
        k.stride(0),
        k.stride(1),
        k.stride(2),
        # V
        v,
        v.stride(0),
        v.stride(1),
        v.stride(2),
        # State
        state,
        state.stride(0),
        state.stride(1),
        state.stride(2),
        state.stride(3),
        # WKV grad
        grad_wkv,
        grad_wkv.stride(0),
        grad_wkv.stride(1),
        grad_wkv.stride(2),
        # Output state grad
        grad_state,
        grad_state.stride(0),
        grad_state.stride(1),
        grad_state.stride(3),
        # W grad
        gw,
        gw.stride(0),
        # U grad
        gu,
        gu.stride(0),
        # K grad
        gk,
        gk.stride(0),
        gk.stride(1),
        gk.stride(2),
        # V grad
        gv,
        gv.stride(0),
        gv.stride(1),
        gv.stride(2),
        # State grad
        gstate,
        gstate.stride(0),
        gstate.stride(1),
        gstate.stride(3),
        # Params
        tsz,
        chans,
        BLOCK_SIZE_C=block_size_c,
    )

    return gw, gu, gk, gv, gstate

class FusedRecurrentRWKV4Function(Function):
    @staticmethod
    def forward(
        ctx: FunctionCtx,
        w: Tensor,
        u: Tensor,
        k: Tensor,
        v: Tensor,
        state: Tensor,
    ) -> tuple[Tensor, Tensor]:
        ctx.input_dtype = k.dtype

        if (
            w.device.type != "cuda"
            or u.device.type != "cuda"
            or k.device.type != "cuda"
            or v.device.type != "cuda"
            or state.device.type != "cuda"
        ):
            raise ValueError(
                "Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices."
            )

        # The kernels consume the transformed decay -exp(w); float16 inputs
        # are upcast since the recurrence is accumulated in float32.
        w = -torch.exp(w.float().contiguous())
        if k.dtype == torch.float16:
            u = u.float()
            k = k.float()
            v = v.float()
        u = u.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        wkv, state_out = fused_recurrent_rwkv4_forward(w, u, k, v, state)
        # Saves the states *before* each timestep and returns the final state
        # to the caller. The saved tensor is a view into the full state
        # buffer, which the backward kernel relies on.
        ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1])
        return wkv, state_out[:, :, -1:]

    @staticmethod
    @once_differentiable
    def backward(ctx: FunctionCtx, gwkv: Tensor, gstate: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors)
        gw, gu, gk, gv, gstate = fused_recurrent_rwkv4_backward(w, u, k, v, state, gwkv, gstate)
        return gw, gu, gk, gv, gstate


def fused_recurrent_rwkv4(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
    return FusedRecurrentRWKV4Function.apply(w, u, k, v, state)

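if __name__ == "__main__":
    # A minimal smoke test sketching the expected usage; the shapes and
    # initial-state values here are illustrative assumptions. Requires a CUDA
    # device and a working Triton install.
    if not torch.cuda.is_available():
        raise SystemExit("This smoke test requires a CUDA device.")
    bsz, tsz, chans = 2, 8, 64
    w = torch.randn(chans, device="cuda")
    u = torch.randn(chans, device="cuda")
    k = torch.randn(bsz, tsz, chans, device="cuda")
    v = torch.randn(bsz, tsz, chans, device="cuda")
    # Initial state: alpha = beta = 0, with eps very negative so the first
    # maximum always selects the incoming key.
    state = torch.zeros(bsz, 3, 1, chans, device="cuda")
    state[:, 2] = -1e30
    wkv, state_out = fused_recurrent_rwkv4(w, u, k, v, state)
    # Compares against the eager reference, which takes the transformed decay.
    ref = _reference_rwkv4_forward(-torch.exp(w.float()), u, k, v, state)
    print("max abs err:", (wkv - ref).abs().max().item())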