From 1b693d00284d133eb0d236361791e7b382d79997 Mon Sep 17 00:00:00 2001 From: CD22104 <1242884655@qq.com> Date: Wed, 16 Apr 2025 15:49:52 +0800 Subject: [PATCH] issue523 --- diffsynth/models/wan_video_dit.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 93d108a..fa2b9f0 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -36,6 +36,8 @@ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x, tuple): + x = x[0] x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) elif FLASH_ATTN_2_AVAILABLE: q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)