RWKV-Runner/finetune/lora/v6/fla/ops/rwkv6/recurrent_naive.py

103 lines
2.9 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
from typing import Optional
import torch
def naive_recurrent_rwkv6(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
):
    """Step-by-step (naive) RWKV-6 recurrence, evaluated in float32.

    A running state ``S`` of shape ``(B, H, K, V)`` is carried over time.
    At each step ``t`` the output is read from the state *before* it is
    updated, plus a bonus term for the current token:

        o_t = sum_k scale * q_t[k] * (S[k, :] + u[k] * k_t[k] * v_t[:])

    after which the state is decayed and the current token is written:

        S <- S * exp(w_t)[:, None] + outer(k_t, v_t)

    Args:
        q: query, unpacked as ``(B, H, T, K)``.
        k: key; expected ``(B, H, T, K)`` so per-step slices line up with
            the state's K axis.
        v: value; its last dimension gives ``V``.
        w: log-decay, expected ``(B, H, T, K)``; exponentiated before use.
        u: bonus, broadcast as ``u[None, ..., None]`` against the state
            (presumably ``(H, K)`` — confirm with callers).
        scale: multiplier applied to ``q``; ``None`` means ``K ** -0.5``.
        initial_state: optional tensor added into the zero initial state.
        output_final_state: when truthy, also return the final state.

    Returns:
        ``(o, ht)``: ``o`` has ``v``'s shape cast back to ``q``'s original
        dtype; ``ht`` is the float32 final state, or ``None``.
    """
    out_dtype = q.dtype
    n_batch, n_heads, n_steps, k_dim = q.shape
    v_dim = v.shape[-1]
    q, k, v, w, u = (x.float() for x in (q, k, v, w, u))

    if scale is None:
        scale = k_dim ** -0.5

    state = torch.zeros(n_batch, n_heads, k_dim, v_dim, dtype=torch.float32, device=q.device)
    if initial_state is not None:
        # In-place so the state keeps its float32 dtype and shape.
        state += initial_state

    out = torch.zeros_like(v)
    # Add a leading batch axis and a trailing value axis so u broadcasts
    # against the (B, H, K, V) state.
    bonus = u[None, ..., None]
    for t in range(n_steps):
        k_t = k[:, :, t]
        kv_t = k_t.unsqueeze(-1) * v[:, :, t].unsqueeze(-2)  # outer(k_t, v_t)
        q_t = q[:, :, t] * scale
        # Read from the pre-update state plus the current token's bonus.
        out[:, :, t] = ((state + bonus * kv_t) * q_t.unsqueeze(-1)).sum(-2)
        # Decay each key channel of the state, then write the current token.
        state = state * w[:, :, t].exp().unsqueeze(-1) + kv_t

    final_state = state if output_final_state else None
    return out.to(out_dtype), final_state
def naive_recurrent_rwkv6_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    o: torch.Tensor,
    do: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    scale: float = 1.0
):
    """Reference (naive) backward pass for the RWKV-6 recurrence.

    Recomputes the forward states and returns the gradients of
    ``sum(o * do)`` w.r.t. ``q, k, v, w, u`` for the same recurrence as
    ``naive_recurrent_rwkv6``:

        o_t = sum_k scale * q_t[k] * (S[k, :] + u[k] * k_t[k] * v_t[:])
        S  <- S * exp(w_t)[:, None] + outer(k_t, v_t)

    Args:
        q, k: ``(B, H, T, K)`` query / key.
        v: ``(B, H, T, V)`` value.
        w: ``(B, H, T, K)`` log-decay (exponentiated before use).
        u: bonus, broadcast as ``u[None, ..., None]`` (presumably ``(H, K)``).
        o: forward output; accepted for interface compatibility, unused.
        do: ``(B, H, T, V)`` upstream gradient of the output.
        initial_state: optional state added into the zero initial state.
            Its own gradient is not computed.
        output_final_state: unused. Any gradient flowing into the final
            state is not supported and is treated as zero.
        scale: multiplier applied to ``q``. Defaults to 1.0 to preserve the
            previous behaviour. The forward's ``scale=None`` means
            ``K ** -0.5``, so pass that value explicitly to match it.

    Returns:
        ``(dq, dk, dv, dw, du)`` as float32 tensors shaped like
        ``q, k, v, w, u`` respectively.
    """
    q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
    B, H, T, K, V = *q.shape, v.shape[-1]
    # o is linear in the scaled query. Scaling q once here makes every
    # gradient that reads q (dk, dv, du, dw) carry the factor automatically;
    # dq is scaled explicitly below. With scale == 1.0 this is an exact no-op.
    q = q * scale
    h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
    dq = torch.zeros_like(q)
    dq_aux = torch.zeros_like(q)
    if initial_state is not None:
        h += initial_state

    # Forward sweep: h is the state before step i.
    # dq uses the state plus the bonus term.
    # dq_aux (state only) is kept for dw.
    for i in range(T):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        w_i = w[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        h_i = h + u[None, ..., None] * kv_i
        dq[:, :, i] = (do[:, :, i, None, :] * h_i).sum(-1) * scale
        dq_aux[:, :, i] = (do[:, :, i, None, :] * h).sum(-1)
        h = h * w_i[..., None] + kv_i

    du = torch.zeros_like(u)
    dh = torch.zeros_like(h)
    dk = torch.zeros_like(k)
    dk_aux = torch.zeros_like(k)
    dv = torch.zeros_like(v)
    # Reverse sweep: on entry to iteration i, dh is the gradient w.r.t. the
    # state after step i. It starts at zero because no final-state gradient
    # is supplied.
    for i in range(T - 1, -1, -1):
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
        # du_i is per (batch, head), while u is shared across the batch
        # (broadcast via u[None, ..., None]). Reduce to u's shape before
        # accumulating; an in-place add of the (B, H, K) tensor into du
        # would raise a broadcast-shape error.
        du += du_i.sum_to_size(*du.shape)
        dk_i = (dh * v_i[..., None, :]).sum(-1)
        # State-path-only part of dk; needed for dw below.
        dk_aux[:, :, i] = dk_i
        dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
        dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
        dv_i += (dh * k_i[..., None]).sum(-2)
        dk[:, :, i] = dk_i
        dv[:, :, i] = dv_i
        dh = dh * w[:, :, i, :, None].exp() + d_kv_i

    # dw_i = sum_{j>i} q_j * dq_aux_j - sum_{j>=i} k_j * dk_aux_j.
    # This is everything written into the state before step i (including the
    # initial state) and read at a later step. dk_aux for the last step is
    # zero. dw for the last step stays zero because that decay only reaches
    # the final state, which carries no gradient here.
    dw = torch.zeros_like(w)
    for i in range(T - 2, -1, -1):
        dw[:, :, i] = dw[:, :, i + 1] + dq_aux[:, :, i + 1] * q[:, :, i + 1] - dk_aux[:, :, i] * k[:, :, i]
    return dq, dk, dv, dw, du