mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
@@ -18,6 +18,12 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_2_AVAILABLE = False
|
||||
|
||||
# SageAttention is an optional attention backend. Probe for it once at import
# time and record the outcome in SAGE_ATTN_AVAILABLE so that callers can
# branch on the flag instead of re-attempting the import.
try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ImportError:
    # ImportError is the parent of ModuleNotFoundError, so the "package not
    # installed" case is still handled. Catching the parent additionally
    # covers a package that is present but cannot provide `sageattn` (e.g. a
    # failing compiled extension or an incompatible release), which would
    # otherwise abort the import of this whole module.
    SAGE_ATTN_AVAILABLE = False
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
@@ -127,6 +133,12 @@ def flash_attention(
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
deterministic=deterministic).unflatten(0, (b, lq))
|
||||
elif SAGE_ATTN_AVAILABLE:
|
||||
q = q.unsqueeze(0).transpose(1, 2).to(dtype)
|
||||
k = k.unsqueeze(0).transpose(1, 2).to(dtype)
|
||||
v = v.unsqueeze(0).transpose(1, 2).to(dtype)
|
||||
x = sageattn(q, k, v, dropout_p=dropout_p, is_causal=causal)
|
||||
x = x.transpose(1, 2).contiguous()
|
||||
else:
|
||||
q = q.unsqueeze(0).transpose(1, 2).to(dtype)
|
||||
k = k.unsqueeze(0).transpose(1, 2).to(dtype)
|
||||
|
||||
Reference in New Issue
Block a user