diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index f1e5e47..32a79e3 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -24,8 +24,14 @@ except ModuleNotFoundError:
     SAGE_ATTN_AVAILABLE = False
 
 
-def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int):
-    if FLASH_ATTN_3_AVAILABLE:
+def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
+    if compatibility_mode:
+        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
+        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
+        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
+        x = F.scaled_dot_product_attention(q, k, v)
+        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+    elif FLASH_ATTN_3_AVAILABLE:
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
diff --git a/diffsynth/models/wan_video_image_encoder.py b/diffsynth/models/wan_video_image_encoder.py
index b49235b..5ca878b 100644
--- a/diffsynth/models/wan_video_image_encoder.py
+++ b/diffsynth/models/wan_video_image_encoder.py
@@ -260,7 +260,7 @@ class SelfAttention(nn.Module):
         q, k, v = self.to_qkv(x).chunk(3, dim=-1)
 
         # compute attention
-        x = flash_attention(q, k, v, num_heads=self.num_heads)
+        x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
 
         # output
         x = self.proj(x)
@@ -371,7 +371,7 @@ class AttentionPool(nn.Module):
         k, v = self.to_kv(x).chunk(2, dim=-1)
 
         # compute attention
-        x = flash_attention(q, k, v, num_heads=self.num_heads)
+        x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
 
         x = x.reshape(b, 1, c)
 
         # output