From 852c3d831fe6b2d5ddbe74eb1edf1c4d4e75e53c Mon Sep 17 00:00:00 2001 From: philipy1219 <602203830@qq.com> Date: Sun, 2 Mar 2025 15:09:21 +0800 Subject: [PATCH 1/2] support sageattn --- diffsynth/models/wan_video_dit.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 395642a..3435391 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -18,6 +18,12 @@ try: except ModuleNotFoundError: FLASH_ATTN_2_AVAILABLE = False +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + import warnings @@ -127,6 +133,12 @@ def flash_attention( causal=causal, window_size=window_size, deterministic=deterministic).unflatten(0, (b, lq)) + elif SAGE_ATTN_AVAILABLE: + q = q.unsqueeze(0).transpose(1, 2).to(dtype) + k = k.unsqueeze(0).transpose(1, 2).to(dtype) + v = v.unsqueeze(0).transpose(1, 2).to(dtype) + x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal) + x = x.transpose(1, 2).contiguous() else: q = q.unsqueeze(0).transpose(1, 2).to(dtype) k = k.unsqueeze(0).transpose(1, 2).to(dtype) From da8e1fe7e4d87cbe203bc67b6b960857abe110d4 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 3 Mar 2025 14:19:16 +0800 Subject: [PATCH 2/2] support sage attention --- examples/wanvideo/README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index ecfc536..4972c26 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -10,6 +10,13 @@ cd DiffSynth-Studio pip install -e . ``` +Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. 
+ +* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention) +* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention) +* [Sage Attention](https://github.com/thu-ml/SageAttention) +* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default; `torch>=2.5.0` is recommended) +
## Inference ### Wan-Video-1.3B-T2V