103 lines
2.9 KiB
Python
103 lines
2.9 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
from typing import Optional
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def naive_recurrent_rwkv6(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
):
    """Step-by-step reference RWKV-6 recurrence, computed in float32.

    At every step ``t`` the output reads the running ``K x V`` state plus a
    "bonus" copy of the current key/value outer product weighted by ``u``,
    and only afterwards folds that outer product into the decayed state::

        o_t = sum_k (scale * q_t)[k] * (S[k, :] + u[k] * k_t[k] * v_t)
        S   = exp(w_t)[:, None] * S + outer(k_t, v_t)

    Args:
        q: queries, unpacked as ``(B, H, T, K)``.
        k: keys, indexed per step like ``q``.
        v: values; the last dim is ``V`` (indexed as ``(B, H, T, V)``).
        w: per-step log decay indexed like ``k``; the state is multiplied by
            ``exp(w_t)`` after step ``t`` (presumably ``w <= 0`` so the decay
            stays ``<= 1`` -- not enforced here).
        u: bonus weights, broadcast as ``u[None, ..., None]`` against the
            ``(B, H, K, V)`` state (e.g. shape ``(H, K)``).
        scale: query multiplier; ``None`` means ``K ** -0.5``.
        initial_state: added to the zero-initialised ``(B, H, K, V)`` state.
        output_final_state: when true, also return the last state.

    Returns:
        ``(o, ht)``: ``o`` has ``v``'s shape and is cast back to ``q``'s
        original dtype; ``ht`` is the float32 final state, or ``None``.
    """
    value_dim = v.shape[-1]
    batch, heads, seq_len, key_dim = q.shape
    input_dtype = q.dtype
    q, k, v, w, u = (t.float() for t in (q, k, v, w, u))

    if scale is None:
        scale = key_dim ** -0.5
    scaled_q = q * scale

    state = torch.zeros(batch, heads, key_dim, value_dim, dtype=torch.float32, device=q.device)
    if initial_state is not None:
        state += initial_state

    bonus = u[None, ..., None]
    out = torch.zeros_like(v)
    for t in range(seq_len):
        # (B, H, K, 1) * (B, H, 1, V): this step's key/value outer product.
        outer = k[:, :, t].unsqueeze(-1) * v[:, :, t].unsqueeze(-2)
        # Read the *previous* state plus the bonus term, weighted by the query.
        read = (state + bonus * outer) * scaled_q[:, :, t].unsqueeze(-1)
        out[:, :, t] = read.sum(dim=-2)
        state = state * w[:, :, t].exp().unsqueeze(-1) + outer

    final_state = state if output_final_state else None
    return out.to(input_dtype), final_state
|
||
|
|
||
|
|
||
|
def naive_recurrent_rwkv6_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    o: torch.Tensor,
    do: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    scale: Optional[float] = 1.0,
):
    """Hand-derived gradients of the naive RWKV-6 recurrence w.r.t. q, k, v, w, u.

    Forward recurrence being differentiated (state ``S`` starts at
    ``initial_state`` or zero)::

        o_t = sum_k (scale * q_t)[k] * (S[k, :] + u[k] * k_t[k] * v_t)
        S   = exp(w_t)[:, None] * S + outer(k_t, v_t)

    The recurrence is first replayed forward to get ``dq`` and the
    state-only readout ``dq_aux``. Time is then walked backwards while
    carrying ``dh = dL/dS`` to get ``dk``, ``dv`` and ``du``. ``dw`` is
    assembled from the two auxiliary sums. All math is done in float32 and
    the gradients are returned in float32.

    Args:
        q, k: indexed ``(B, H, T, K)``.
        v: indexed ``(B, H, T, V)``.
        w: per-step log decay, indexed like ``k``.
        u: bonus weights, broadcast as ``u[None, ..., None]`` against the
            ``(B, H, K, V)`` state (e.g. shape ``(H, K)``).
        o: forward output; accepted for interface compatibility, not read.
        do: upstream gradient of ``o``, indexed ``(B, H, T, V)``.
        initial_state: added to the zero-initialised state, as in the
            forward. No gradient is returned for it.
        output_final_state: accepted for interface compatibility and
            ignored. No gradient flowing into the final state is considered.
        scale: query multiplier the forward used. Defaults to ``1.0``, which
            is the behaviour this function always had. Pass the same value
            given to the forward, or ``None`` for its ``K ** -0.5`` default.

    Returns:
        ``(dq, dk, dv, dw, du)``, each shaped like the corresponding input.
    """
    q, k, v, w, u, do = (x.float() for x in (q, k, v, w, u, do))
    B, H, T, K, V = *q.shape, v.shape[-1]
    if scale is None:
        scale = K ** -0.5
    # All gradients below are taken w.r.t. the query the forward actually used.
    q = q * scale
    bonus = u[None, ..., None]

    # Forward replay: dL/d(scaled q), plus the state-only part needed for dw.
    h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
    if initial_state is not None:
        h += initial_state
    dq = torch.zeros_like(q)
    dq_aux = torch.zeros_like(q)
    for i in range(T):
        kv_i = k[:, :, i, :, None] * v[:, :, i, None, :]
        do_i = do[:, :, i, None, :]
        dq[:, :, i] = (do_i * (h + bonus * kv_i)).sum(-1)
        dq_aux[:, :, i] = (do_i * h).sum(-1)
        h = h * w[:, :, i, :, None].exp() + kv_i
    # Chain rule through the forward's `q * scale`.
    dq = dq * scale

    du = torch.zeros_like(u)
    dh = torch.zeros_like(h)  # dL/dS_i for the state written at step i
    dk = torch.zeros_like(k)
    dk_aux = torch.zeros_like(k)
    dv = torch.zeros_like(v)
    for i in range(T - 1, -1, -1):
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        # o_i reads S_{i-1} and the bonus term through outer(q_i, do_i).
        d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
        # The forward broadcasts u over the batch, so the per-batch term must
        # be summed back down to u's own shape. An in-place `+=` of the
        # (B, H, K) term into a (H, K) buffer cannot broadcast and raises.
        du += (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1).sum_to_size(du.shape)
        # Part of dk coming from later steps only (via the state); kept for dw.
        dk_state = (dh * v_i[..., None, :]).sum(-1)
        dk_aux[:, :, i] = dk_state
        dk[:, :, i] = dk_state + (d_kv_i * bonus * v_i[..., None, :]).sum(-1)
        dv[:, :, i] = (d_kv_i * bonus * k_i[..., None]).sum(-2) + (dh * k_i[..., None]).sum(-2)
        dh = dh * w[:, :, i, :, None].exp() + d_kv_i

    # exp(w_j) scales every state contribution written before step j and read
    # after it. sum_{t>j} q_t * dq_aux_t counts all contributions read after
    # step j, and sum_{s>=j} k_s * dk_aux_s removes those written at or after
    # step j. The last step's decay only reaches the final state, which gets
    # no gradient here, so dw[:, :, T-1] stays zero.
    dw = torch.zeros_like(w)
    for i in range(T - 2, -1, -1):
        dw[:, :, i] = dw[:, :, i + 1] + dq_aux[:, :, i + 1] * q[:, :, i + 1] - dk_aux[:, :, i] * k[:, :, i]

    return dq, dk, dv, dw, du
|