186 lines
5.3 KiB
Python
Vendored
186 lines
5.3 KiB
Python
Vendored
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright (c) 2023, Songlin Yang
|
|
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from fla.utils import contiguous
|
|
|
|
|
|
@triton.autotune(
|
|
configs=[
|
|
triton.Config({'BD': 32}, num_warps=1),
|
|
triton.Config({'BD': 32}, num_warps=2),
|
|
triton.Config({'BD': 32}, num_warps=4),
|
|
triton.Config({'BD': 32}, num_warps=8),
|
|
triton.Config({'BD': 64}, num_warps=1),
|
|
triton.Config({'BD': 64}, num_warps=2),
|
|
triton.Config({'BD': 64}, num_warps=4),
|
|
triton.Config({'BD': 64}, num_warps=8),
|
|
triton.Config({'BD': 128}, num_warps=1),
|
|
triton.Config({'BD': 128}, num_warps=2),
|
|
triton.Config({'BD': 128}, num_warps=4),
|
|
triton.Config({'BD': 128}, num_warps=8),
|
|
],
|
|
key=['D']
|
|
)
|
|
@triton.jit
|
|
def fused_recurrent_hgrn_fwd_kernel(
|
|
x,
|
|
g,
|
|
o,
|
|
h0,
|
|
ht,
|
|
T: tl.constexpr,
|
|
D: tl.constexpr,
|
|
BD: tl.constexpr,
|
|
USE_INITIAL_STATE: tl.constexpr,
|
|
STORE_FINAL_STATE: tl.constexpr
|
|
):
|
|
i_d, i_bh = tl.program_id(0), tl.program_id(1)
|
|
o_d = i_d * BD + tl.arange(0, BD)
|
|
mask = o_d < D
|
|
|
|
p_x = x + i_bh * T * D + o_d
|
|
p_g = g + i_bh * T * D + o_d
|
|
p_o = o + i_bh * T * D + o_d
|
|
|
|
b_h = tl.zeros([BD], dtype=tl.float32)
|
|
if USE_INITIAL_STATE:
|
|
p_h0 = h0 + i_bh * D + o_d
|
|
b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
|
|
for _ in range(0, T):
|
|
b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
|
|
b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
|
|
b_h = tl.exp(b_g) * b_h + b_x
|
|
tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)
|
|
|
|
p_x += D
|
|
p_g += D
|
|
p_o += D
|
|
|
|
if STORE_FINAL_STATE:
|
|
p_ht = ht + i_bh * D + o_d
|
|
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
|
|
|
|
|
|
@triton.autotune(
|
|
configs=[
|
|
triton.Config({'BD': 32}, num_warps=1),
|
|
triton.Config({'BD': 32}, num_warps=2),
|
|
triton.Config({'BD': 32}, num_warps=4),
|
|
triton.Config({'BD': 32}, num_warps=8),
|
|
triton.Config({'BD': 64}, num_warps=1),
|
|
triton.Config({'BD': 64}, num_warps=2),
|
|
triton.Config({'BD': 64}, num_warps=4),
|
|
triton.Config({'BD': 64}, num_warps=8),
|
|
triton.Config({'BD': 128}, num_warps=1),
|
|
triton.Config({'BD': 128}, num_warps=2),
|
|
triton.Config({'BD': 128}, num_warps=4),
|
|
triton.Config({'BD': 128}, num_warps=8),
|
|
],
|
|
key=['D']
|
|
)
|
|
@triton.jit
|
|
def fused_recurrent_hgrn_bwd_kernel(
|
|
g,
|
|
o,
|
|
dx,
|
|
dg,
|
|
do,
|
|
h0,
|
|
T: tl.constexpr,
|
|
D: tl.constexpr,
|
|
BD: tl.constexpr,
|
|
USE_INITIAL_STATE: tl.constexpr
|
|
):
|
|
i_d, i_bh = tl.program_id(0), tl.program_id(1)
|
|
o_d = i_d * BD + tl.arange(0, BD)
|
|
mask = o_d < D
|
|
|
|
p_g = g + (i_bh * T + T - 1) * D + o_d
|
|
p_o = o + (i_bh * T + T - 2) * D + o_d
|
|
p_dx = dx + (i_bh * T + T - 1) * D + o_d
|
|
p_dg = dg + (i_bh * T + T - 1) * D + o_d
|
|
p_do = do + (i_bh * T + T - 1) * D + o_d
|
|
|
|
b_dh = tl.zeros([BD], dtype=tl.float32)
|
|
for i in range(T - 1, -1, -1):
|
|
b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
|
|
b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
|
|
if i > 0:
|
|
b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
|
|
elif USE_INITIAL_STATE:
|
|
b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
|
|
else:
|
|
b_o = tl.zeros([BD], dtype=tl.float32)
|
|
|
|
b_dh = b_dh + b_do
|
|
b_dx = b_dh
|
|
b_dh = b_dh * tl.exp(b_g)
|
|
b_dg = b_dh * b_o
|
|
tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
|
|
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)
|
|
|
|
p_g -= D
|
|
p_o -= D
|
|
p_dx -= D
|
|
p_dg -= D
|
|
p_do -= D
|
|
|
|
|
|
class FusedRecurrentHGRNFunction(torch.autograd.Function):
|
|
|
|
@staticmethod
|
|
@contiguous
|
|
def forward(ctx, x, g, initial_state=None, output_final_state=False):
|
|
B, H, T, D = x.shape
|
|
|
|
final_state = None
|
|
if output_final_state:
|
|
final_state = x.new_empty(B, H, D)
|
|
|
|
o = torch.empty_like(x)
|
|
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
|
|
fused_recurrent_hgrn_fwd_kernel[grid](
|
|
x, g, o, initial_state, final_state,
|
|
T, D,
|
|
USE_INITIAL_STATE=initial_state is not None,
|
|
STORE_FINAL_STATE=final_state is not None
|
|
)
|
|
ctx.save_for_backward(g, o, initial_state)
|
|
return o, final_state
|
|
|
|
@staticmethod
|
|
@contiguous
|
|
def backward(ctx, do, dht=None):
|
|
g, o, initial_state = ctx.saved_tensors
|
|
B, H, T, D = do.shape
|
|
|
|
dx = torch.empty_like(o)
|
|
dg = torch.empty_like(g)
|
|
def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
|
|
fused_recurrent_hgrn_bwd_kernel[grid](
|
|
g, o, dx, dg, do, initial_state,
|
|
T, D,
|
|
USE_INITIAL_STATE=initial_state is not None,
|
|
)
|
|
|
|
return dx, dg, None, None
|
|
|
|
|
|
def fused_recurrent_hgrn(
|
|
x: torch.Tensor,
|
|
g: torch.Tensor,
|
|
initial_state: torch.Tensor = None,
|
|
output_final_state: bool = False
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
if initial_state is not None:
|
|
initial_state = initial_state.detach()
|
|
o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state)
|
|
return o, final_state
|