253 lines
9.1 KiB
Python
Vendored
253 lines
9.1 KiB
Python
Vendored
# Copyright (c) 2023, Tri Dao. https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py
|
|
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
# @triton.autotune(
|
|
# configs=[
|
|
# triton.Config({"BLOCK_M": 2}),
|
|
# triton.Config({"BLOCK_M": 4}),
|
|
# triton.Config({"BLOCK_M": 8}),
|
|
# triton.Config({"BLOCK_M": 16}),
|
|
# ],
|
|
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"],
|
|
# )
|
|
@triton.jit
def rotary_kernel(
    OUT,  # Pointers to matrices
    X,
    COS,
    SIN,
    CU_SEQLENS,
    SEQLEN_OFFSETS,  # this could be int or a pointer
    # Matrix dimensions
    seqlen,
    nheads,
    rotary_dim,
    seqlen_ro,
    CACHE_KEY_SEQLEN,
    # strides
    stride_out_batch,
    stride_out_seqlen,
    stride_out_nheads,
    stride_out_headdim,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_nheads,
    stride_x_headdim,
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    """Rotate the first ``rotary_dim`` features of X by COS/SIN and write to OUT.

    Launch grid: axis 0 indexes blocks of BLOCK_M sequence rows, axis 1 the
    batch entry (or the sequence index when IS_VARLEN), axis 2 the head.

    COS and SIN are indexed as row-major ``(seqlen_ro, rotary_dim // 2)``
    tables. The table row used for sequence row ``r`` is ``r + offset``, where
    the offset is the scalar SEQLEN_OFFSETS, or ``SEQLEN_OFFSETS[batch]`` when
    IS_SEQLEN_OFFSETS_TENSOR is set.

    When IS_VARLEN is set, the start and length of the current sequence are
    read from CU_SEQLENS and override the ``seqlen`` argument.

    Table rows at or beyond ``seqlen_ro`` are loaded as cos=1, sin=0, so those
    positions are copied through unrotated. CONJUGATE negates sin.
    ``nheads`` and ``CACHE_KEY_SEQLEN`` are not read in the body.
    """
    row_block = tl.program_id(axis=0)
    batch_idx = tl.program_id(axis=1)
    head_idx = tl.program_id(axis=2)
    half_dim = rotary_dim // 2

    # Advance the base pointers to the (batch or sequence, head) slice.
    if IS_VARLEN:
        seq_start = tl.load(CU_SEQLENS + batch_idx)
        seqlen = tl.load(CU_SEQLENS + batch_idx + 1) - seq_start
        X = X + seq_start * stride_x_seqlen + head_idx * stride_x_nheads
        OUT = OUT + seq_start * stride_out_seqlen + head_idx * stride_out_nheads
    else:
        X = X + batch_idx * stride_x_batch + head_idx * stride_x_nheads
        OUT = OUT + batch_idx * stride_out_batch + head_idx * stride_out_nheads

    # Nothing to do for row blocks that start past the end of this sequence.
    if row_block * BLOCK_M >= seqlen:
        return

    rows = row_block * BLOCK_M + tl.arange(0, BLOCK_M)
    # Row indices into the cos/sin tables, shifted by the sequence offset.
    if IS_SEQLEN_OFFSETS_TENSOR:
        rows_cs = rows + tl.load(SEQLEN_OFFSETS + batch_idx)
    else:
        rows_cs = rows + SEQLEN_OFFSETS

    cols = tl.arange(0, BLOCK_K)
    cols_half = tl.arange(0, BLOCK_K // 2)

    if INTERLEAVED:
        # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
        # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
        # Loading x0 will be fast but x1 will be slow.
        # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
        # Then we do the calculation and use tl.where to pick out the right outputs for the even
        # and for the odd indices.
        cols_swap = cols + ((cols + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
        cols_pair = tl.arange(0, BLOCK_K) // 2  # 0, 0, 1, 1, 2, 2, ...

        row_in_seq = rows[:, None] < seqlen
        mask_x = row_in_seq & (cols[None, :] < rotary_dim)
        mask_x_swap = row_in_seq & (cols_swap[None, :] < rotary_dim)
        mask_cs = (rows_cs[:, None] < seqlen_ro) & (cols_pair[None, :] < half_dim)

        x_ptrs = X + (rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_headdim)
        x_swap_ptrs = X + (
            rows[:, None] * stride_x_seqlen + cols_swap[None, :] * stride_x_headdim
        )
        offs_cs = rows_cs[:, None] * half_dim + cols_pair[None, :]

        cos = tl.load(COS + offs_cs, mask=mask_cs, other=1.0).to(tl.float32)
        sin = tl.load(SIN + offs_cs, mask=mask_cs, other=0.0).to(tl.float32)
        x0 = tl.load(x_ptrs, mask=mask_x, other=0.0).to(tl.float32)
        x1 = tl.load(x_swap_ptrs, mask=mask_x_swap, other=0.0).to(tl.float32)
        if CONJUGATE:
            sin = -sin

        x0_cos = x0 * cos
        x1_sin = x1 * sin
        # Even columns take x0*cos - x1*sin, odd columns take x0*cos + x1*sin.
        out = tl.where(cols[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)

        out_ptrs = OUT + (
            rows[:, None] * stride_out_seqlen + cols[None, :] * stride_out_headdim
        )
        tl.store(out_ptrs, out, mask=mask_x)
    else:
        # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
        col_in_half = cols_half[None, :] < half_dim
        mask_x = (rows[:, None] < seqlen) & col_in_half
        mask_cs = (rows_cs[:, None] < seqlen_ro) & col_in_half

        x_lo_ptrs = X + (
            rows[:, None] * stride_x_seqlen + cols_half[None, :] * stride_x_headdim
        )
        offs_cs = rows_cs[:, None] * half_dim + cols_half[None, :]

        cos = tl.load(COS + offs_cs, mask=mask_cs, other=1.0).to(tl.float32)
        sin = tl.load(SIN + offs_cs, mask=mask_cs, other=0.0).to(tl.float32)
        x0 = tl.load(x_lo_ptrs, mask=mask_x, other=0.0).to(tl.float32)
        x1 = tl.load(
            x_lo_ptrs + half_dim * stride_x_headdim, mask=mask_x, other=0.0
        ).to(tl.float32)
        if CONJUGATE:
            sin = -sin

        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos

        # write back result
        out_lo_ptrs = OUT + (
            rows[:, None] * stride_out_seqlen + cols_half[None, :] * stride_out_headdim
        )
        tl.store(out_lo_ptrs, o0, mask=mask_x)
        tl.store(out_lo_ptrs + half_dim * stride_out_headdim, o1, mask=mask_x)
|
|
|
|
|
|
def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """Apply rotary embedding to the first ``rotary_dim`` features of ``x``.

    Arguments:
        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim).
        cos: (seqlen_ro, rotary_dim / 2)
        sin: (seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,); added to
            the sequence position when indexing cos/sin.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int, required when cu_seqlens is given.
        interleaved: if True, rotate adjacent feature pairs (0, 1), (2, 3), ...
            instead of pairing the first and second halves of the rotary
            features.
        inplace: if True, write the result into ``x`` and return ``x``.
        conjugate: if True, negate sin before rotating.
    Returns:
        y: tensor with the same shape as ``x``, i.e.
            (batch, seqlen, nheads, headdim) or (total_seqlen, nheads, headdim).
            Features at index >= rotary_dim are copied from ``x`` unchanged.
    Raises:
        AssertionError: if the shape, dtype or length checks below fail.
    """
    is_varlen = cu_seqlens is not None
    if not is_varlen:
        batch, seqlen, nheads, headdim = x.shape
    else:
        assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
        # The unpacking still enforces that x is 3-D; total_seqlen itself is unused.
        _, nheads, headdim = x.shape
        batch = cu_seqlens.shape[0] - 1
        seqlen = max_seqlen
        # The kernel reads cu_seqlens with unit stride, so a strided view
        # would yield wrong sequence boundaries.
        cu_seqlens = cu_seqlens.contiguous()
    seqlen_ro, rotary_dim = cos.shape
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"

    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    # The kernel indexes cos/sin as dense row-major tables.
    cos, sin = cos.contiguous(), sin.contiguous()
    is_offsets_tensor = isinstance(seqlen_offsets, torch.Tensor)
    if is_offsets_tensor:
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        # The kernel only writes the rotary features; carry the rest over.
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    # Smallest of 32/64/128/256 that covers rotary_dim (headdim <= 256 is asserted above).
    if rotary_dim <= 32:
        BLOCK_K = 32
    elif rotary_dim <= 64:
        BLOCK_K = 64
    elif rotary_dim <= 128:
        BLOCK_K = 128
    else:
        BLOCK_K = 256
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)

    def grid(META):
        # One program per (block of BLOCK_M rows, batch entry, head).
        return (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)

    # Need this, otherwise Triton tries to launch from cuda:0 and we get
    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    with torch.cuda.device(x.device.index):
        rotary_kernel[grid](
            output,  # data ptrs
            x,
            cos,
            sin,
            cu_seqlens,
            seqlen_offsets,
            seqlen,  # shapes
            nheads,
            rotary_dim,
            seqlen_ro,
            # key for triton cache (limit number of compilations)
            seqlen // 128,
            # batch_strides if not varlen else 0
            output.stride(0) if not is_varlen else 0,
            output.stride(-3),  # seqlen_stride or total_seqlen_stride
            output.stride(-2),  # nheads_stride
            output.stride(-1),  # headdim_stride
            # batch_strides if not varlen else 0
            x.stride(0) if not is_varlen else 0,
            x.stride(-3),  # seqlen stride or total_seqlen_stride
            x.stride(-2),  # nheads stride
            x.stride(-1),  # headdim stride
            BLOCK_K,
            is_offsets_tensor,
            is_varlen,
            interleaved,
            conjugate,
            BLOCK_M,
        )
    return output
|