Merge pull request #537 from CD22104/main

Fix issue #523
Authored by Zhongjie Duan on 2025-04-16 15:53:39 +08:00, committed by GitHub


@@ -36,6 +36,8 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
         x = flash_attn_interface.flash_attn_func(q, k, v)
+        if isinstance(x, tuple):
+            x = x[0]
         x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
     elif FLASH_ATTN_2_AVAILABLE:
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
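
For context on the two added lines: flash_attn_interface is the FlashAttention-3 entry point, and in some builds its flash_attn_func returns a (out, softmax_lse) tuple rather than a bare output tensor, so the result has to be unwrapped before the final rearrange. A minimal standalone sketch of that normalization, with unwrap_attn_output as a hypothetical helper name and dummy tensors standing in for real attention outputs:

import torch

def unwrap_attn_output(x):
    # Some flash_attn_func builds return (out, softmax_lse); others return
    # just the output tensor. Take the output tensor in either case.
    if isinstance(x, tuple):
        x = x[0]
    return x

# Dummy shapes: a (batch, seq, heads, head_dim) output and a stand-in lse tensor.
out = torch.randn(1, 16, 8, 64)
lse = torch.randn(1, 8, 16)
assert unwrap_attn_output((out, lse)) is out   # tuple return is unwrapped
assert unwrap_attn_output(out) is out          # plain tensor passes through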