This commit is contained in:
102
finetune/lora/v6/fla/ops/rwkv6/recurrent_naive.py
vendored
Normal file
102
finetune/lora/v6/fla/ops/rwkv6/recurrent_naive.py
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def naive_recurrent_rwkv6(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
scale: Optional[float] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: Optional[bool] = False
|
||||
):
|
||||
orig_dtype = q.dtype
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
|
||||
h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
|
||||
o = torch.zeros_like(v)
|
||||
|
||||
if scale is None:
|
||||
scale = K ** -0.5
|
||||
|
||||
if initial_state is not None:
|
||||
h += initial_state
|
||||
|
||||
for i in range(T):
|
||||
q_i = q[:, :, i, :] * scale
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i, :]
|
||||
w_i = w[:, :, i].exp()
|
||||
kv_i = k_i[..., None] * v_i[..., None, :]
|
||||
o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None]
|
||||
o[:, :, i] = o_i.sum(-2)
|
||||
h = h * w_i[..., None] + kv_i
|
||||
ht = h if output_final_state else None
|
||||
return o.to(orig_dtype), ht
|
||||
|
||||
|
||||
def naive_recurrent_rwkv6_bwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
u,
|
||||
o,
|
||||
do,
|
||||
initial_state=None,
|
||||
output_final_state=False
|
||||
):
|
||||
q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do))
|
||||
B, H, T, K, V = *q.shape, v.shape[-1]
|
||||
h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
|
||||
dq = torch.zeros_like(q)
|
||||
dq_aux = torch.zeros_like(q)
|
||||
|
||||
if initial_state is not None:
|
||||
h += initial_state
|
||||
|
||||
for i in range(T):
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i]
|
||||
w_i = w[:, :, i].exp()
|
||||
kv_i = k_i[..., None] * v_i[..., None, :]
|
||||
h_i = (h + u[None, ..., None] * kv_i)
|
||||
dq_i = (do[:, :, i, None, :] * h_i).sum(-1)
|
||||
dq_aux_i = (do[:, :, i, None, :] * h).sum(-1)
|
||||
dq[:, :, i] = dq_i
|
||||
dq_aux[:, :, i] = dq_aux_i
|
||||
h = h * w_i[..., None] + kv_i
|
||||
|
||||
du = torch.zeros_like(u)
|
||||
dh = torch.zeros_like(h)
|
||||
dk = torch.zeros_like(k)
|
||||
dk_aux = torch.zeros_like(k)
|
||||
dv = torch.zeros_like(v)
|
||||
|
||||
for i in range(T - 1, -1, -1):
|
||||
d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None]
|
||||
k_i = k[:, :, i]
|
||||
v_i = v[:, :, i]
|
||||
du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1)
|
||||
du += du_i
|
||||
dk_i = (dh * v_i[..., None, :]).sum(-1)
|
||||
dk_aux[:, :, i] = dk_i
|
||||
dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1)
|
||||
dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2)
|
||||
dv_i += (dh * k_i[..., None]).sum(-2)
|
||||
|
||||
dk[:, :, i] = dk_i
|
||||
dv[:, :, i] = dv_i
|
||||
dh = dh * w[:, :, i, :, None].exp() + d_kv_i
|
||||
|
||||
# dw = q * dq_aux - k * dk_aux
|
||||
dw = torch.zeros_like(w)
|
||||
for i in range(T - 2, -1, -1):
|
||||
dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i]
|
||||
|
||||
return dq, dk, dv, dw, du
|
||||
Reference in New Issue
Block a user