This commit is contained in:
Artiprocher
2025-03-10 18:25:23 +08:00
parent e757013a14
commit 718b45f2af
2 changed files with 10 additions and 4 deletions

View File

@@ -260,7 +260,7 @@ class SelfAttention(nn.Module):
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
# compute attention
x = flash_attention(q, k, v, num_heads=self.num_heads)
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
# output
x = self.proj(x)
@@ -371,7 +371,7 @@ class AttentionPool(nn.Module):
k, v = self.to_kv(x).chunk(2, dim=-1)
# compute attention
x = flash_attention(q, k, v, num_heads=self.num_heads)
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
x = x.reshape(b, 1, c)
# output