mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
122 lines
5.6 KiB
Python
122 lines
5.6 KiB
Python
import torch, os
|
|
from einops import rearrange
|
|
|
|
|
|
# Optional attention backends.
#
# Each backend is probed once at import time and the outcome is recorded in a
# module-level ``*_AVAILABLE`` flag. ``ImportError`` is caught rather than only
# its subclass ``ModuleNotFoundError`` so that a package which is installed but
# cannot actually be loaded is reported as unavailable instead of aborting the
# import of this module.
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ImportError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ImportError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ImportError:
    SAGE_ATTN_AVAILABLE = False

try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False
|
|
|
|
|
|
def initialize_attention_priority():
    """Return the name of the attention backend this module should use.

    If the ``DIFFSYNTH_ATTENTION_IMPLEMENTATION`` environment variable is set,
    its value takes precedence and is returned lower-cased with surrounding
    whitespace removed. The value is not validated here; ``attention_forward``
    routes any name it does not recognise to the torch implementation.

    Otherwise the first backend whose availability flag is true is chosen, in
    this order: ``"flash_attention_3"``, ``"flash_attention_2"``,
    ``"sage_attention"``, ``"xformers"``; ``"torch"`` is the final fallback.

    Returns:
        str: the selected backend name.
    """
    requested = os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION')
    if requested is not None:
        # Strip so that a stray space or newline in the variable does not turn
        # a valid backend name into an unrecognised one.
        return requested.strip().lower()
    if FLASH_ATTN_3_AVAILABLE:
        return "flash_attention_3"
    if FLASH_ATTN_2_AVAILABLE:
        return "flash_attention_2"
    if SAGE_ATTN_AVAILABLE:
        return "sage_attention"
    if XFORMERS_AVAILABLE:
        return "xformers"
    return "torch"
|
|
|
|
|
|
# Backend name resolved once, when this module is imported; attention_forward
# compares against this value to decide which implementation to call.
ATTENTION_IMPLEMENTATION = initialize_attention_priority()
|
|
|
|
|
|
def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
    """Bring q, k and v into the einops layout ``required_in_pattern``.

    Each tensor is passed through ``einops.rearrange`` only when its own
    pattern differs from ``required_in_pattern``; otherwise it is returned
    untouched. ``dims`` (a mapping, or ``None`` for none) is forwarded to
    ``rearrange`` as keyword arguments.

    Returns:
        tuple: ``(q, k, v)`` in the required layout.
    """
    axis_sizes = dims if dims is not None else {}
    converted = []
    for tensor, pattern in ((q, q_pattern), (k, k_pattern), (v, v_pattern)):
        if pattern != required_in_pattern:
            tensor = rearrange(tensor, f"{pattern} -> {required_in_pattern}", **axis_sizes)
        converted.append(tensor)
    q, k, v = converted
    return q, k, v
|
|
|
|
|
|
def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
    """Convert ``out`` from ``required_out_pattern`` to ``out_pattern``.

    When both patterns are equal the tensor is returned as-is. Otherwise it is
    passed through ``einops.rearrange``, with ``dims`` (a mapping, or ``None``
    for none) forwarded as keyword arguments.
    """
    # Nothing to do when the layout is already the one requested.
    if out_pattern == required_out_pattern:
        return out
    axis_sizes = dims if dims is not None else {}
    return rearrange(out, f"{required_out_pattern} -> {out_pattern}", **axis_sizes)
|
|
|
|
|
|
def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
    """Attention via ``torch.nn.functional.scaled_dot_product_attention``.

    Inputs are rearranged from their given patterns to ``"b n s d"``, the
    kernel is called with ``attn_mask`` and ``scale``, and the result is
    rearranged from ``"b n s d"`` to ``out_pattern``.
    """
    layout = "b n s d"  # used for both the kernel inputs and its output
    q, k, v = rearrange_qkv(
        q, k, v,
        q_pattern=q_pattern, k_pattern=k_pattern, v_pattern=v_pattern,
        required_in_pattern=layout, dims=dims,
    )
    attended = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, scale=scale,
    )
    return rearrange_out(attended, out_pattern=out_pattern, required_out_pattern=layout, dims=dims)
|
|
|
|
|
|
def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via ``flash_attn_interface.flash_attn_func``.

    Inputs are rearranged from their given patterns to ``"b s n d"``, the
    kernel is called with ``softmax_scale=scale``, and the result is
    rearranged from ``"b s n d"`` to ``out_pattern``.
    """
    layout = "b s n d"  # used for both the kernel inputs and its output
    q, k, v = rearrange_qkv(
        q, k, v,
        q_pattern=q_pattern, k_pattern=k_pattern, v_pattern=v_pattern,
        required_in_pattern=layout, dims=dims,
    )
    result = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
    # The call may hand back a tuple; only its first element is used.
    attended = result[0] if isinstance(result, tuple) else result
    return rearrange_out(attended, out_pattern=out_pattern, required_out_pattern=layout, dims=dims)
|
|
|
|
|
|
def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via ``flash_attn.flash_attn_func``.

    Inputs are rearranged from their given patterns to ``"b s n d"``, the
    kernel is called with ``softmax_scale=scale``, and the result is
    rearranged from ``"b s n d"`` to ``out_pattern``.
    """
    layout = "b s n d"  # used for both the kernel inputs and its output
    q, k, v = rearrange_qkv(
        q, k, v,
        q_pattern=q_pattern, k_pattern=k_pattern, v_pattern=v_pattern,
        required_in_pattern=layout, dims=dims,
    )
    attended = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
    return rearrange_out(attended, out_pattern=out_pattern, required_out_pattern=layout, dims=dims)
|
|
|
|
|
|
def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via ``sageattention.sageattn``.

    Inputs are rearranged from their given patterns to ``"b n s d"``, the
    kernel is called with ``sm_scale=scale``, and the result is rearranged
    from ``"b n s d"`` to ``out_pattern``.
    """
    layout = "b n s d"  # used for both the kernel inputs and its output
    q, k, v = rearrange_qkv(
        q, k, v,
        q_pattern=q_pattern, k_pattern=k_pattern, v_pattern=v_pattern,
        required_in_pattern=layout, dims=dims,
    )
    attended = sageattn(q, k, v, sm_scale=scale)
    return rearrange_out(attended, out_pattern=out_pattern, required_out_pattern=layout, dims=dims)
|
|
|
|
|
|
def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    """Attention via ``xformers.ops.memory_efficient_attention``.

    Inputs are rearranged from their given patterns to ``"b s n d"``, the
    kernel is called with ``scale=scale``, and the result is rearranged from
    ``"b s n d"`` to ``out_pattern``.
    """
    layout = "b s n d"  # used for both the kernel inputs and its output
    q, k, v = rearrange_qkv(
        q, k, v,
        q_pattern=q_pattern, k_pattern=k_pattern, v_pattern=v_pattern,
        required_in_pattern=layout, dims=dims,
    )
    attended = xops.memory_efficient_attention(q, k, v, scale=scale)
    return rearrange_out(attended, out_pattern=out_pattern, required_out_pattern=layout, dims=dims)
|
|
|
|
|
|
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
    """Compute attention with the backend named by ``ATTENTION_IMPLEMENTATION``.

    The torch implementation is used unconditionally when
    ``compatibility_mode`` is true or an ``attn_mask`` is supplied (it is the
    only path here that receives the mask). Otherwise the call is routed to
    the flash-attention 3, flash-attention 2, SageAttention or xformers
    wrapper matching ``ATTENTION_IMPLEMENTATION``; any other value falls back
    to the torch implementation.

    The pattern arguments and ``dims`` are forwarded unchanged to the chosen
    wrapper, as is ``scale``.
    """
    # Masked or compatibility calls always go through torch.
    if compatibility_mode or attn_mask is not None:
        return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)

    if ATTENTION_IMPLEMENTATION == "flash_attention_3":
        kernel = flash_attention_3
    elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
        kernel = flash_attention_2
    elif ATTENTION_IMPLEMENTATION == "sage_attention":
        kernel = sage_attention
    elif ATTENTION_IMPLEMENTATION == "xformers":
        kernel = xformers_attention
    else:
        # Unrecognised backend names end up on the torch path.
        kernel = torch_sdpa
    return kernel(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
|