finetune/lora/v6/fla/ops/hgrn/__init__.py (vendored, new file, +9 lines)
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_hgrn
from .recurrent_fuse import fused_recurrent_hgrn

__all__ = [
    'chunk_hgrn',
    'fused_recurrent_hgrn'
]
finetune/lora/v6/fla/ops/hgrn/chunk.py (vendored, new file, +373 lines)
@@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-

# Copyright (c) 2024, Yu Zhang, Songlin Yang

# This module implements the chunkwise form of HGRN, inspired by
# [Volodymyr Kyrylov's blog post](https://proger.github.io/posts/scan/chunk.html);
# see also the `accelerated-scan` lib: https://github.com/proger/accelerated-scan

# From tests on an H800 with B, H, D = 16, 4, 128, the chunk form is considerably
# faster than the recurrent form (times in ms):
#
# Performance:
#    seq_len     chunk  recurrent  chunk_bwd  recurrent_bwd
# 0    128.0  0.039360   0.061056   0.312160       0.205008
# 1    256.0  0.045824   0.123712   0.308784       0.297696
# 2    512.0  0.058688   0.241952   0.310720       0.626528
# 3   1024.0  0.088288   0.476992   0.313184       1.333152
# 4   2048.0  0.169472   0.943264   0.452464       2.724864
# 5   4096.0  0.329920   1.886144   0.881600       5.551520
# 6   8192.0  0.647872   3.755040   1.740496      11.117184
# 7  16384.0  1.272064   7.520576   3.446608      22.362528

from typing import Tuple

import torch
import triton
import triton.language as tl

from fla.utils import contiguous


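# For intuition, a minimal pure-PyTorch sketch of the two-pass chunkwise scheme
# that the Triton kernels below implement. `ref_chunk_hgrn` is an illustrative
# helper added for exposition only (the name is not part of the fla API); it
# ignores the optional initial state and makes no attempt at being fast.
# Pass 1 is independent per chunk (parallelized over chunks in the kernel grid);
# pass 2 is sequential over chunks but elementwise over B, H, D.
def ref_chunk_hgrn(x: torch.Tensor, g: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
    B, H, T, D = x.shape
    o = torch.empty_like(x, dtype=torch.float)
    gc = torch.empty_like(g, dtype=torch.float)
    # pass 1 (cf. chunk_hgrn_fwd_kernel_h): scan each chunk from a zero state
    # and keep the running sum of log-gates for the cross-chunk correction
    for s in range(0, T, chunk_size):
        h = x.new_zeros(B, H, D, dtype=torch.float)
        acc = x.new_zeros(B, H, D, dtype=torch.float)
        for i in range(s, min(s + chunk_size, T)):
            h = g[:, :, i].float().exp() * h + x[:, :, i].float()
            acc = acc + g[:, :, i].float()
            o[:, :, i], gc[:, :, i] = h, acc
    # pass 2 (cf. chunk_hgrn_fwd_kernel_o): carry the (already corrected) last
    # state of the previous chunk into the current chunk through exp(gc)
    for s in range(chunk_size, T, chunk_size):
        e = min(s + chunk_size, T)
        b_h0 = o[:, :, s - 1]
        o[:, :, s:e] = o[:, :, s:e] + gc[:, :, s:e].exp() * b_h0[:, :, None]
    return o.to(x.dtype)

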
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit
def chunk_hgrn_fwd_kernel_h(
    x,
    g,
    gc,
    o,
    h0,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr
):
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    p_x = x + i_bh * T * D + i_t * BT * D + o_d
    p_g = g + i_bh * T * D + i_t * BT * D + o_d
    p_gc = gc + i_bh * T * D + i_t * BT * D + o_d
    p_o = o + i_bh * T * D + i_t * BT * D + o_d

    b_h = tl.zeros([BD], dtype=tl.float32)
    b_gc = tl.zeros([BD], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if i_t == 0:
            b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
    for i in range(0, BT):
        mask_t = mask & ((i_t * BT + i) < T)
        b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)
        b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)
        b_h = tl.exp(b_g) * b_h + b_x
        b_gc = b_gc + b_g
        tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)
        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)

        p_x += D
        p_g += D
        p_gc += D
        p_o += D


@triton.jit
def chunk_hgrn_fwd_kernel_o(
    gc,
    o,
    s_h,
    s_t,
    s_d,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    for i_t in range(1, tl.cdiv(T, BT)):
        p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))

        # [BD,]
        b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)
        # [BT, BD]
        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
        b_o = b_o + tl.exp(b_gc) * b_h0[None, :]
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


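# Backward pass: chunk_hgrn_bwd_kernel_h runs a reversed scan within each
# chunk, writing dx and the accumulated log-gates into `gc`;
# chunk_hgrn_bwd_kernel_o then propagates the gradient arriving from the
# following chunk back through exp(gc) and forms dg_t = o_{t-1} * dx_t * exp(g_t).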
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit
def chunk_hgrn_bwd_kernel_h(
    g,
    gc,
    dx,
    do,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D
    BC = min(BT, T - i_t * BT)
    NT = tl.num_programs(1)

    p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d
    p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d
    p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d
    p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d

    if i_t == NT - 1:
        b_gc = tl.zeros([BD], dtype=tl.float32)
    else:
        b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)
    b_dh = tl.zeros([BD], dtype=tl.float32)
    for _ in range(BC - 1, -1, -1):
        tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)

        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)

        b_gc = b_gc + b_g
        b_dh = b_dh + b_do
        b_dx = b_dh
        b_dh = b_dh * tl.exp(b_g)

        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)

        p_g -= D
        p_gc -= D
        p_dx -= D
        p_do -= D


@triton.jit
def chunk_hgrn_bwd_kernel_o(
    g,
    gc,
    o,
    dx,
    dg,
    s_h,
    s_t,
    s_d,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
        p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))
        p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))

        # [BD,]
        mask_t = mask & ((i_t + 1) * BT < T)
        b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)
        # [BT, BD]
        b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
        b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)
        b_dg = tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32)
        b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]
        b_dg = b_o * b_dx * tl.exp(b_g)
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))


class ChunkHGRNFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    def forward(ctx, x, g, initial_state=None, output_final_state=False):
        B, H, T, D = x.shape
        BT, BD = 128, min(64, triton.next_power_of_2(D))
        num_warps = 8 if BD == 64 else 4

        gc = torch.empty_like(g, dtype=torch.float)
        o = torch.empty_like(x, dtype=torch.float)
        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
        chunk_hgrn_fwd_kernel_h[grid](
            x, g, gc, o, initial_state,
            T, D,
            BT=BT,
            USE_INITIAL_STATE=initial_state is not None
        )
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
        chunk_hgrn_fwd_kernel_o[grid](
            gc, o,
            o.stride(1), o.stride(2), o.stride(3),
            T, D,
            BT=BT, BD=BD,
            num_warps=num_warps
        )
        final_state = None
        if output_final_state:
            final_state = o[:, :, -1].clone()
        o = o.to(x.dtype)
        ctx.save_for_backward(g, o, initial_state)
        return o, final_state

    @staticmethod
    @contiguous
    def backward(ctx, do, dht=None):
        g, o, initial_state = ctx.saved_tensors
        B, H, T, D = do.shape
        BT, BD = 128, min(64, triton.next_power_of_2(D))
        num_warps = 8 if BD == 64 else 4

        gc = torch.empty_like(g, dtype=torch.float)
        dx = torch.empty_like(o)
        dg = torch.empty_like(g)
        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
        chunk_hgrn_bwd_kernel_h[grid](
            g, gc, dx, do,
            T, D,
            BT=BT
        )
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
        chunk_hgrn_bwd_kernel_o[grid](
            g, gc, o, dx, dg,
            o.stride(1), o.stride(2), o.stride(3),
            T, D,
            BT=BT, BD=BD,
            num_warps=num_warps
        )
        if initial_state is not None:
            dg[:, :, 0] = initial_state * dx[:, :, 0] * g[:, :, 0].exp()

        return dx, dg, None, None


def chunk_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    if initial_state is not None:
        initial_state = initial_state.detach()
    o, final_state = ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)
    return o, final_state


if __name__ == '__main__':
    import torch.nn.functional as F

    from fla.ops.hgrn.naive import naive_recurrent_hgrn
    from fla.ops.hgrn.recurrent_fuse import fused_recurrent_hgrn
    B, H, T, D = 8, 4, 512, 128
    dtype = torch.bfloat16
    torch.manual_seed(42)
    # [batch_size, n_heads, seq_len, d_head]
    x = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
    g = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
    x, g = (1 - g.sigmoid()) * x, F.logsigmoid(g)
    print(f'x:\t{float(x.min()):>10.6f}\t{float(x.max()):>10.6f}')
    print(f'g:\t{float(g.min()):>10.6f}\t{float(g.max()):>10.6f}')
    x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g))
    print(f"DTYPE:\t{x.dtype}")
    do = torch.randn_like(x)
    h0 = torch.randn_like(x[:, :, 0])
    ref, ref_ht = naive_recurrent_hgrn(x, g, h0, output_final_state=True)
    ref.backward(do)
    ref_dx, x.grad = x.grad.clone(), None
    ref_dg, g.grad = g.grad.clone(), None

    tri, tri_ht = fused_recurrent_hgrn(x, g, h0, output_final_state=True)
    tri.backward(do)
    tri_dx, x.grad = x.grad.clone(), None
    tri_dg, g.grad = g.grad.clone(), None
    print(" \t DIFF\t MAX")
    print(' o\t', f"{float((ref - tri).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}")
    print('ht\t', f"{float((ref_ht[0] - tri_ht[0]).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}")
    print('dx\t', f"{float((ref_dx - tri_dx).abs().max()):>10.6f}\t{float(ref_dx.max()):>10.6f}")
    print('dg\t', f"{float((ref_dg - tri_dg).abs().max()):>10.6f}\t{float(ref_dg.max()):>10.6f}")
    print('Done!')

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            # argument names to use as an x-axis for the plot
            x_names=['seq_len'],
            # different possible values for `x_name`
            x_vals=[128 * 2 ** i for i in range(0, 8)],
            # argument name whose value corresponds to a different line in the plot
            line_arg='provider',
            # possible values for `line_arg`
            line_vals=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
            # label name for the lines
            line_names=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'],
            # line styles
            styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')],
            ylabel="Execution Time (ms)",  # label name for the y-axis
            # name for the plot. Used also as a file name for saving the plot.
            plot_name="Performance",
            args={},
        )
    )
    def benchmark(seq_len, provider):
        dtype = torch.bfloat16
        B, H, D = 16, 4, 128

        x = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda')
        g = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda').sigmoid()
        x = (1 - g) * x
        x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g))
        do = torch.randn_like(x, dtype=dtype)
        quantiles = [0.5, 0.2, 0.8]
        results = 0, 0, 0
        if provider == 'chunk':
            results = triton.testing.do_bench(lambda: chunk_hgrn(x, g), quantiles=quantiles)
        if provider == 'recurrent':
            results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g), quantiles=quantiles)
        if provider == 'chunk_bwd':
            results = triton.testing.do_bench(lambda: chunk_hgrn(x, g)[0].backward(do), quantiles=quantiles)
        if provider == 'recurrent_bwd':
            results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g)[0].backward(do), quantiles=quantiles)
        return results
    benchmark.run(print_data=True)
finetune/lora/v6/fla/ops/hgrn/naive.py (vendored, new file, +31 lines)
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-

from typing import Optional, Tuple

import torch


def naive_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: Optional[bool] = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    dtype = x.dtype
    x, g = map(lambda i: i.float(), (x, g))
    B, H, T, D = x.shape

    h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
    o = torch.zeros_like(x)

    final_state = None
    if initial_state is not None:
        h += initial_state.detach()

    for i in range(T):
        h = g[:, :, i].exp() * h + x[:, :, i]
        o[:, :, i] = h

    if output_final_state:
        final_state = h
    return o.to(dtype), final_state
finetune/lora/v6/fla/ops/hgrn/recurrent_fuse.py (vendored, new file, +185 lines)
@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-

# Copyright (c) 2023, Songlin Yang

from typing import Tuple

import torch
import triton
import triton.language as tl

from fla.utils import contiguous


@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit
def fused_recurrent_hgrn_fwd_kernel(
    x,
    g,
    o,
    h0,
    ht,
    T: tl.constexpr,
    D: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr
):
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    p_x = x + i_bh * T * D + o_d
    p_g = g + i_bh * T * D + o_d
    p_o = o + i_bh * T * D + o_d

    b_h = tl.zeros([BD], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * D + o_d
        b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
    for _ in range(0, T):
        b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_h = tl.exp(b_g) * b_h + b_x
        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)

        p_x += D
        p_g += D
        p_o += D

    if STORE_FINAL_STATE:
        p_ht = ht + i_bh * D + o_d
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)


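# Backward of the recurrence h_t = exp(g_t) * h_{t-1} + x_t, processed in
# reverse time order with a running gradient dh_t = do_t + exp(g_{t+1}) * dh_{t+1}:
#   dx_t = dh_t
#   dg_t = dh_t * exp(g_t) * h_{t-1}   (h_{-1} is the initial state, or zero)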
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit
def fused_recurrent_hgrn_bwd_kernel(
    g,
    o,
    dx,
    dg,
    do,
    h0,
    T: tl.constexpr,
    D: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr
):
    i_d, i_bh = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    p_g = g + (i_bh * T + T - 1) * D + o_d
    p_o = o + (i_bh * T + T - 2) * D + o_d
    p_dx = dx + (i_bh * T + T - 1) * D + o_d
    p_dg = dg + (i_bh * T + T - 1) * D + o_d
    p_do = do + (i_bh * T + T - 1) * D + o_d

    b_dh = tl.zeros([BD], dtype=tl.float32)
    for i in range(T - 1, -1, -1):
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
        if i > 0:
            b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
        elif USE_INITIAL_STATE:
            b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
        else:
            b_o = tl.zeros([BD], dtype=tl.float32)

        b_dh = b_dh + b_do
        b_dx = b_dh
        b_dh = b_dh * tl.exp(b_g)
        b_dg = b_dh * b_o
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)

        p_g -= D
        p_o -= D
        p_dx -= D
        p_dg -= D
        p_do -= D


class FusedRecurrentHGRNFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    def forward(ctx, x, g, initial_state=None, output_final_state=False):
        B, H, T, D = x.shape

        final_state = None
        if output_final_state:
            final_state = x.new_empty(B, H, D)

        o = torch.empty_like(x)
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
        fused_recurrent_hgrn_fwd_kernel[grid](
            x, g, o, initial_state, final_state,
            T, D,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=final_state is not None
        )
        ctx.save_for_backward(g, o, initial_state)
        return o, final_state

    @staticmethod
    @contiguous
    def backward(ctx, do, dht=None):
        g, o, initial_state = ctx.saved_tensors
        B, H, T, D = do.shape

        dx = torch.empty_like(o)
        dg = torch.empty_like(g)
        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
        fused_recurrent_hgrn_bwd_kernel[grid](
            g, o, dx, dg, do, initial_state,
            T, D,
            USE_INITIAL_STATE=initial_state is not None,
        )

        return dx, dg, None, None


def fused_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    if initial_state is not None:
        initial_state = initial_state.detach()
    o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state)
    return o, final_state