diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index 395642a..3435391 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -18,6 +18,15 @@ try:
 except ModuleNotFoundError:
     FLASH_ATTN_2_AVAILABLE = False
 
+# SageAttention is an optional backend. ImportError (the parent class of
+# ModuleNotFoundError) is caught so that an installed-but-unloadable package
+# disables the backend instead of breaking the import of this module.
+try:
+    from sageattention import sageattn
+    SAGE_ATTN_AVAILABLE = True
+except ImportError:
+    SAGE_ATTN_AVAILABLE = False
+
 import warnings
 
 
@@ -127,6 +136,15 @@ def flash_attention(
             causal=causal,
             window_size=window_size,
             deterministic=deterministic).unflatten(0, (b, lq))
+    elif SAGE_ATTN_AVAILABLE:
+        # Same tensor preparation as the fallback branch below.
+        q = q.unsqueeze(0).transpose(1, 2).to(dtype)
+        k = k.unsqueeze(0).transpose(1, 2).to(dtype)
+        v = v.unsqueeze(0).transpose(1, 2).to(dtype)
+        # dropout_p is deliberately not forwarded: sageattn does not
+        # implement attention dropout.
+        x = sageattn(q, k, v, is_causal=causal)
+        x = x.transpose(1, 2).contiguous()
     else:
         q = q.unsqueeze(0).transpose(1, 2).to(dtype)
         k = k.unsqueeze(0).transpose(1, 2).to(dtype)