diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 395642a..3435391 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -18,6 +18,12 @@ try: except ModuleNotFoundError: FLASH_ATTN_2_AVAILABLE = False +try: + from sageattention import sageattn + SAGE_ATTN_AVAILABLE = True +except ModuleNotFoundError: + SAGE_ATTN_AVAILABLE = False + import warnings @@ -127,6 +133,12 @@ def flash_attention( causal=causal, window_size=window_size, deterministic=deterministic).unflatten(0, (b, lq)) + elif SAGE_ATTN_AVAILABLE: + q = q.unsqueeze(0).transpose(1, 2).to(dtype) + k = k.unsqueeze(0).transpose(1, 2).to(dtype) + v = v.unsqueeze(0).transpose(1, 2).to(dtype) + x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal) + x = x.transpose(1, 2).contiguous() else: q = q.unsqueeze(0).transpose(1, 2).to(dtype) k = k.unsqueeze(0).transpose(1, 2).to(dtype) diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index ecfc536..4972c26 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -10,6 +10,13 @@ cd DiffSynth-Studio pip install -e . ``` +Wan-Video supports multiple attention implementations. If more than one of the following attention implementations is installed, the one with the highest priority (listed first below) will be used. + +* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention) +* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention) +* [Sage Attention](https://github.com/thu-ml/SageAttention) +* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default; `torch>=2.5.0` is recommended) + ## Inference ### Wan-Video-1.3B-T2V