# -*- coding: utf-8 -*-
# adopted from https://github.com/codekansas/rwkv

"""Fused recurrent RWKV4 wkv computation (forward and backward) implemented with Triton kernels."""

from typing import Any, cast

import torch
import triton
import triton.language as tl
from torch import Tensor
from torch.autograd.function import Function, FunctionCtx, once_differentiable


def get_block_size_c(chans: int) -> int:
    if chans < 32:
        return 32
    if chans < 64:
        return 64
    return 128


@triton.jit
def fused_recurrent_rwkv4_forward_kernel(
    # W
    w_ptr,
    w_s_c,
    # U
    u_ptr,
    u_s_c,
    # K
    k_ptr,
    k_s_b,
    k_s_t,
    k_s_c,
    # V
    v_ptr,
    v_s_b,
    v_s_t,
    v_s_c,
    # State
    state_ptr,
    state_s_b,
    state_s_abe,
    state_s_c,
    # WKV
    wkv_ptr,
    wkv_s_b,
    wkv_s_t,
    wkv_s_c,
    # Output state
    state_out_ptr,
    state_out_s_b,
    state_out_s_abe,
    state_out_s_t,
    state_out_s_c,
    # Params
    chans,
    tsz,
    BLOCK_SIZE_C: tl.constexpr,
):
    # Parallelize over the batch and channel-block dimensions.
    b_idx = tl.program_id(0)
    c_idx = tl.program_id(1)

    cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
    cmask = cs < chans

    # Pointers to the batch (and possibly channel) for the input tensors.
    k_ptr = k_ptr + b_idx * k_s_b
    v_ptr = v_ptr + b_idx * v_s_b
    alpha_ptr = state_ptr + b_idx * state_s_b
    beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
    eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe

    # Pointers to the batch (and possibly channel) for the output tensors.
    wkv_ptr = wkv_ptr + b_idx * wkv_s_b
    alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b
    beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe
    eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe

    # Loads parameters.
    alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
    beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
    eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)

    w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)
    u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)

    for t in range(tsz):
        kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)
        vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)

        ukt = u + kt
        tau = tl.maximum(ukt, eps)
        e1a = tl.exp(eps - tau)
        e2a = tl.exp(ukt - tau)
        wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)
        tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)

        w_eps = w + eps
        eps = tl.maximum(w_eps, kt)
        e1b = tl.exp(w_eps - eps)
        e2b = tl.exp(kt - eps)
        alpha = e1b * alpha + e2b * vt
        beta = e1b * beta + e2b
        tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)
        tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)
        tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)


def fused_recurrent_rwkv4_forward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
) -> tuple[Tensor, Tensor]:
    (bsz, tsz, chans) = k.shape

    # New tensors to output.
    wkvs = k.new_empty(bsz, tsz, chans)
    state_out = k.new_empty(bsz, 3, tsz, chans)

    # Constants.
    block_size_c = get_block_size_c(chans)

    def grid(meta: dict[str, Any]) -> tuple[int, ...]:
        return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))

    fused_recurrent_rwkv4_forward_kernel[grid](
        # W
        w,
        w.stride(0),
        # U
        u,
        u.stride(0),
        # K
        k,
        k.stride(0),
        k.stride(1),
        k.stride(2),
        # V
        v,
        v.stride(0),
        v.stride(1),
        v.stride(2),
        # State
        state,
        state.stride(0),
        state.stride(1),
        state.stride(3),
        # WKV
        wkvs,
        wkvs.stride(0),
        wkvs.stride(1),
        wkvs.stride(2),
        # Output state
        state_out,
        state_out.stride(0),
        state_out.stride(1),
        state_out.stride(2),
        state_out.stride(3),
        # Params
        chans,
        tsz,
        BLOCK_SIZE_C=block_size_c,
    )

    state_out = torch.cat((state, state_out), dim=2)

    return wkvs, state_out


@triton.jit
def fused_recurrent_rwkv4_backward_kernel(
    # W
    w_ptr,
    w_s_c,
    # U
    u_ptr,
    u_s_c,
    # K
    k_ptr,
    k_s_b,
    k_s_t,
    k_s_c,
    # V
    v_ptr,
    v_s_b,
    v_s_t,
    v_s_c,
    # State
    state_ptr,
    state_s_b,
    state_s_abe,
    state_s_t,
    state_s_c,
    # WKV grad
    gwkv_ptr,
    gwkv_s_b,
    gwkv_s_t,
    gwkv_s_c,
    # Output state grad
    gstate_out_ptr,
    gstate_out_s_b,
    gstate_out_s_abe,
    gstate_out_s_c,
    # W grad
    gw_ptr,
    gw_s_c,
    # U grad
    gu_ptr,
    gu_s_c,
    # K grad
    gk_ptr,
    gk_s_b,
    gk_s_t,
    gk_s_c,
    # V grad
    gv_ptr,
    gv_s_b,
    gv_s_t,
    gv_s_c,
    # State grad
    gstate_ptr,
    gstate_s_b,
    gstate_s_abe,
    gstate_s_c,
    # Params
    tsz,
    chans,
    BLOCK_SIZE_C: tl.constexpr,
):
    # Parallelize over the batch and channel-block dimensions.
    b_idx = tl.program_id(0)
    c_idx = tl.program_id(1)

    cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
    cmask = cs < chans

    # Pointers to the batch (and possibly channel) for the input tensors.
    k_ptr = k_ptr + b_idx * k_s_b
    v_ptr = v_ptr + b_idx * v_s_b
    alpha_ptr = state_ptr + b_idx * state_s_b
    beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
    eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe

    # Pointers to the batch (and possibly channel) for the output tensors.
    gk_ptr = gk_ptr + b_idx * gk_s_b
    gv_ptr = gv_ptr + b_idx * gv_s_b

    # Pointers to gradients which were received by the function.
    gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b
    galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b
    gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe
    geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe

    # Loads parameters.
    galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
    gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
    geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)

    w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)
    u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)

    # Gradient accumulators.
    gw = tl.zeros_like(w)
    gu = tl.zeros_like(u)

    alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
    beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
    eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)

    for t in range(tsz):
        tc = tsz - t - 1

        kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)
        vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)

        alpha_curr = alpha_prev
        beta_curr = beta_prev
        eps_curr = eps_prev

        alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
        beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
        eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)

        ukt = u + kt
        tau = tl.maximum(ukt, eps_prev)
        e1 = tl.exp(eps_prev - tau)
        e2 = tl.exp(ukt - tau)

        euke = tl.exp(ukt + eps_prev - 2 * tau)

        denom = e1 * beta_prev + e2
        denom_sq = denom * denom

        gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)

        # Backpropagates wkv gradients.
        guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq
        gu += guk
        gk = guk
        gv = gwkvt * e2 / denom

        galpha_wkv = gwkvt * e1 / denom
        gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq
        geps_wkv_denom = e1 * beta_prev + e2
        geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)

        e1 = tl.exp(w + eps_prev - eps_curr)
        e2 = tl.exp(kt - eps_curr)

        # Backpropagates alpha gradients.
        galpha_we = galpha * e1 * alpha_prev
        gw += galpha_we
        gk += galpha * e2 * vt
        gv += galpha * e2
        geps += galpha * -alpha_curr

        # Backpropagates beta gradients.
        gbeta_we = gbeta * e1 * beta_prev
        gw += gbeta_we
        gk += gbeta * e2
        geps += gbeta * -beta_curr

        # Backpropagates epsilon gradients.
        geps_mask = w + eps_prev > kt
        geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))
        gw += geps_we
        gk += tl.where(geps_mask, tl.zeros_like(geps), geps)

        # Stores the gradients for k and v.
        tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)
        tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)

        # Computes new gradients for alpha, beta, and epsilon.
        galpha = galpha * e1 + galpha_wkv
        gbeta = gbeta * e1 + gbeta_wkv
        geps = galpha_we + gbeta_we + geps_we + geps_wkv

    # Stores final gradients for alpha, beta, and epsilon.
    galpha_ptr = gstate_ptr + b_idx * gstate_s_b
    gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe
    geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe
    tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)
    tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)
    tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)

    # Stores final gradients for w and u.
    gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32)
    gw_temp += gw
    tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask)
    gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32)
    gu_temp += gu
    tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask)


def fused_recurrent_rwkv4_backward(
    w: Tensor,
    u: Tensor,
    k: Tensor,
    v: Tensor,
    state: Tensor,
    grad_wkv: Tensor,
    grad_state: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    bsz, tsz, chans = k.shape

    # New tensors to output.
    gw = torch.zeros_like(w)
    gu = torch.zeros_like(u)
    gk = torch.empty_like(k)
    gv = torch.empty_like(v)
    gstate = k.new_empty(bsz, 3, 1, chans)

    # Constants.
    block_size_c = get_block_size_c(chans)

    def grid(meta: dict[str, Any]) -> tuple[int, ...]:
        return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))

    fused_recurrent_rwkv4_backward_kernel[grid](
        # W
        w,
        w.stride(0),
        # U
        u,
        u.stride(0),
        # K
        k,
        k.stride(0),
        k.stride(1),
        k.stride(2),
        # V
        v,
        v.stride(0),
        v.stride(1),
        v.stride(2),
        # State
        state,
        state.stride(0),
        state.stride(1),
        state.stride(2),
        state.stride(3),
        # WKV grad
        grad_wkv,
        grad_wkv.stride(0),
        grad_wkv.stride(1),
        grad_wkv.stride(2),
        # Output state grad
        grad_state,
        grad_state.stride(0),
        grad_state.stride(1),
        grad_state.stride(3),
        # W grad
        gw,
        gw.stride(0),
        # U grad
        gu,
        gu.stride(0),
        # K grad
        gk,
        gk.stride(0),
        gk.stride(1),
        gk.stride(2),
        # V grad
        gv,
        gv.stride(0),
        gv.stride(1),
        gv.stride(2),
        # State grad
        gstate,
        gstate.stride(0),
        gstate.stride(1),
        gstate.stride(3),
        # Params
        tsz,
        chans,
        BLOCK_SIZE_C=block_size_c,
    )

    return gw, gu, gk, gv, gstate


class FusedRecurrentRWKV4Function(Function):
    @staticmethod
    def forward(
        ctx: FunctionCtx,
        w: Tensor,
        u: Tensor,
        k: Tensor,
        v: Tensor,
        state: Tensor,
    ) -> tuple[Tensor, Tensor]:
        ctx.input_dtype = k.dtype

        if (
            w.device.type != "cuda"
            or u.device.type != "cuda"
            or k.device.type != "cuda"
            or v.device.type != "cuda"
        ):
            raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")

        w = -torch.exp(w.float().contiguous())
        if k.dtype == torch.float16:
            u = u.float()
            k = k.float()
            v = v.float()
        u = u.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        wkv, state_out = fused_recurrent_rwkv4_forward(w, u, k, v, state)
        ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1])
        return wkv, state_out[:, :, -1:]

    @staticmethod
    @once_differentiable
    def backward(ctx: FunctionCtx, gwkv: Tensor, gstate: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors)
        gw, gu, gk, gv, gstate = fused_recurrent_rwkv4_backward(w, u, k, v, state, gwkv, gstate)
        return gw, gu, gk, gv, gstate


def fused_recurrent_rwkv4(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
    """Runs the fused recurrent RWKV4 wkv computation.

    Args:
        w: Per-channel time-decay parameter, with shape ``(chans,)``.
        u: Per-channel "bonus" parameter, with shape ``(chans,)``.
        k: Keys, with shape ``(bsz, tsz, chans)``.
        v: Values, with shape ``(bsz, tsz, chans)``.
        state: Recurrent state ``(alpha, beta, eps)`` stacked along dim 1,
            with shape ``(bsz, 3, 1, chans)``.

    Returns:
        The wkv outputs, with shape ``(bsz, tsz, chans)``, and the final
        recurrent state, with shape ``(bsz, 3, 1, chans)``.
    """
    return FusedRecurrentRWKV4Function.apply(w, u, k, v, state)
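

# ---------------------------------------------------------------------------
# Minimal usage sketch. This block is illustrative only: the toy shapes, the
# zero-initialized state, and the printed quantities are assumptions chosen
# for a quick smoke test, not part of the kernel API above. It requires a
# CUDA device, since the autograd function rejects non-CUDA tensors.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if not torch.cuda.is_available():
        raise SystemExit("This smoke test requires a CUDA device.")

    bsz, tsz, chans = 2, 4, 64  # Assumed toy sizes.
    device, dtype = torch.device("cuda"), torch.float32

    # Per-channel decay/bonus parameters and the key/value sequences.
    w = torch.randn(chans, device=device, dtype=dtype, requires_grad=True)
    u = torch.randn(chans, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(bsz, tsz, chans, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(bsz, tsz, chans, device=device, dtype=dtype, requires_grad=True)

    # Initial recurrent state: (alpha, beta, eps) stacked along dim 1.
    state = torch.zeros(bsz, 3, 1, chans, device=device, dtype=dtype, requires_grad=True)

    # Forward pass through the fused Triton kernel.
    wkv, state_out = fused_recurrent_rwkv4(w, u, k, v, state)
    print("wkv:", tuple(wkv.shape), "state_out:", tuple(state_out.shape))

    # Backward pass; gradients flow through the fused backward kernel.
    (wkv.sum() + state_out.sum()).backward()
    print("grad norms:", w.grad.norm().item(), u.grad.norm().item(), k.grad.norm().item())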