Merge pull request #392 from modelscope/sage_attention

Sage attention
Zhongjie Duan
2025-03-03 14:28:36 +08:00
committed by GitHub
2 changed files with 19 additions and 0 deletions

@@ -18,6 +18,12 @@ try:
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False
try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False
import warnings
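
The guarded import above records whether `sageattention` is installed in a module-level flag, which the attention dispatcher checks later. A minimal, self-contained sketch of the same feature-detection idea (the dispatcher name and the SDPA fallback here are illustrative, not from the diff):

```python
# Sketch of the optional-backend pattern from the hunk above: try to import
# the accelerated kernel, record availability, and dispatch accordingly.
import torch
import torch.nn.functional as F

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

def attention(q, k, v, is_causal=False):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    if SAGE_ATTN_AVAILABLE:
        return sageattn(q, k, v, is_causal=is_causal)
    # Fallback: PyTorch's built-in scaled dot-product attention.
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```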
@@ -127,6 +133,12 @@ def flash_attention(
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))
    elif SAGE_ATTN_AVAILABLE:
        q = q.unsqueeze(0).transpose(1, 2).to(dtype)
        k = k.unsqueeze(0).transpose(1, 2).to(dtype)
        v = v.unsqueeze(0).transpose(1, 2).to(dtype)
        x = sageattn(q, k, v, is_causal=causal)
        x = x.transpose(1, 2).contiguous()
    else:
        q = q.unsqueeze(0).transpose(1, 2).to(dtype)
        k = k.unsqueeze(0).transpose(1, 2).to(dtype)
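
The reshapes in the new branch convert the packed `(seq_len, num_heads, head_dim)` tensors into the `(batch, num_heads, seq_len, head_dim)` layout that `sageattn` consumes by default, then transpose the result back so it matches the other branches. A standalone sketch of that round-trip, using torch SDPA as a stand-in for `sageattn` (same input layout; shapes are illustrative, and the real code also casts to a configured `dtype`, which is omitted here):

```python
# Layout round-trip from the elif branch above, with torch SDPA standing in
# for sageattn (both consume (batch, heads, seq_len, head_dim) tensors).
import torch
import torch.nn.functional as F

seq_len, num_heads, head_dim = 128, 8, 64
q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)
v = torch.randn(seq_len, num_heads, head_dim)

# (seq, heads, dim) -> (1, heads, seq, dim): add a batch dim, move heads first.
q_, k_, v_ = (t.unsqueeze(0).transpose(1, 2) for t in (q, k, v))
x = F.scaled_dot_product_attention(q_, k_, v_, is_causal=False)
# (1, heads, seq, dim) -> (1, seq, heads, dim), matching the other branches.
x = x.transpose(1, 2).contiguous()
print(x.shape)  # torch.Size([1, 128, 8, 64])
```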

@@ -10,6 +10,13 @@ cd DiffSynth-Studio
pip install -e .
```
Wan-Video supports multiple attention implementations, listed below in priority order. If more than one is installed, the highest-priority one is used (a quick availability check is sketched after this list).
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
* [Sage Attention](https://github.com/thu-ml/SageAttention)
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default; `torch>=2.5.0` is recommended)
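
A quick way to see which of these backends are importable in your environment; the module names here are assumptions based on the packages' usual install names (`flash_attn_interface` for Flash Attention 3, `flash_attn` for Flash Attention 2, `sageattention` for Sage Attention), not taken from this repo's code:

```python
# Report which optional attention backends are importable; torch SDPA is the
# always-available fallback. Module names below are assumed install names.
import importlib.util
import torch

backends = {
    "Flash Attention 3": "flash_attn_interface",
    "Flash Attention 2": "flash_attn",
    "Sage Attention": "sageattention",
}
for name, module in backends.items():
    status = "available" if importlib.util.find_spec(module) else "not installed"
    print(f"{name}: {status}")
print(f"torch SDPA fallback: torch {torch.__version__}")
```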
## Inference
### Wan-Video-1.3B-T2V