mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
bugfix
This commit is contained in:
@@ -24,8 +24,14 @@ except ModuleNotFoundError:
|
|||||||
SAGE_ATTN_AVAILABLE = False
|
SAGE_ATTN_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int):
|
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
|
||||||
if FLASH_ATTN_3_AVAILABLE:
|
if compatibility_mode:
|
||||||
|
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||||
|
x = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||||
|
elif FLASH_ATTN_3_AVAILABLE:
|
||||||
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
||||||
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
||||||
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
||||||
|
|||||||
@@ -260,7 +260,7 @@ class SelfAttention(nn.Module):
|
|||||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
||||||
|
|
||||||
# output
|
# output
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
@@ -371,7 +371,7 @@ class AttentionPool(nn.Module):
|
|||||||
k, v = self.to_kv(x).chunk(2, dim=-1)
|
k, v = self.to_kv(x).chunk(2, dim=-1)
|
||||||
|
|
||||||
# compute attention
|
# compute attention
|
||||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
|
||||||
x = x.reshape(b, 1, c)
|
x = x.reshape(b, 1, c)
|
||||||
|
|
||||||
# output
|
# output
|
||||||
|
|||||||
Reference in New Issue
Block a user