support qwen-image-fp8

This commit is contained in:
Artiprocher
2025-08-07 16:20:50 +08:00
parent d20680baae
commit a0eec8c673
7 changed files with 224 additions and 40 deletions

View File

@@ -1,10 +1,44 @@
import torch
import torch, math
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange
from .sd3_dit import TimestepEmbeddings, RMSNorm
from .flux_dit import AdaLayerNorm
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, enable_fp8_attention: bool = False):
if FLASH_ATTN_3_AVAILABLE:
if not enable_fp8_attention:
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x, tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
origin_dtype = q.dtype
q_std, k_std, v_std = q.std(), k.std(), v.std()
q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
if isinstance(x, tuple):
x = x[0]
x = x.to(origin_dtype) * v_std
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
return x
class ApproximateGELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
@@ -158,7 +192,8 @@ class QwenDoubleStreamAttention(nn.Module):
self,
image: torch.FloatTensor,
text: torch.FloatTensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
enable_fp8_attention: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -186,9 +221,7 @@ class QwenDoubleStreamAttention(nn.Module):
joint_k = torch.cat([txt_k, img_k], dim=2)
joint_v = torch.cat([txt_v, img_v], dim=2)
joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v)
joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype)
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
txt_attn_output = joint_attn_out[:, :seq_txt, :]
img_attn_output = joint_attn_out[:, seq_txt:, :]
@@ -245,6 +278,7 @@ class QwenImageTransformerBlock(nn.Module):
text: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
enable_fp8_attention = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
@@ -260,6 +294,7 @@ class QwenImageTransformerBlock(nn.Module):
image=img_modulated,
text=txt_modulated,
image_rotary_emb=image_rotary_emb,
enable_fp8_attention=enable_fp8_attention,
)
image = image + img_gate * img_attn_out