diff --git a/diffsynth/models/qwen_image_dit.py b/diffsynth/models/qwen_image_dit.py
index 15f8747..423a0a5 100644
--- a/diffsynth/models/qwen_image_dit.py
+++ b/diffsynth/models/qwen_image_dit.py
@@ -1,10 +1,44 @@
-import torch
+import torch, math
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange
from .sd3_dit import TimestepEmbeddings, RMSNorm
from .flux_dit import AdaLayerNorm
+try:
+ import flash_attn_interface
+ FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+ FLASH_ATTN_3_AVAILABLE = False
+
+
+def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, enable_fp8_attention: bool = False):
+ if FLASH_ATTN_3_AVAILABLE:
+ if not enable_fp8_attention:
+ q = rearrange(q, "b n s d -> b s n d", n=num_heads)
+ k = rearrange(k, "b n s d -> b s n d", n=num_heads)
+ v = rearrange(v, "b n s d -> b s n d", n=num_heads)
+ x = flash_attn_interface.flash_attn_func(q, k, v)
+ if isinstance(x, tuple):
+ x = x[0]
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
+ else:
+ origin_dtype = q.dtype
+ q_std, k_std, v_std = q.std(), k.std(), v.std()
+ q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
+ q = rearrange(q, "b n s d -> b s n d", n=num_heads)
+ k = rearrange(k, "b n s d -> b s n d", n=num_heads)
+ v = rearrange(v, "b n s d -> b s n d", n=num_heads)
+ x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
+ if isinstance(x, tuple):
+ x = x[0]
+ x = x.to(origin_dtype) * v_std
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
+ else:
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+ return x
+
class ApproximateGELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
@@ -158,7 +192,8 @@ class QwenDoubleStreamAttention(nn.Module):
self,
image: torch.FloatTensor,
text: torch.FloatTensor,
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ enable_fp8_attention: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -186,9 +221,7 @@ class QwenDoubleStreamAttention(nn.Module):
joint_k = torch.cat([txt_k, img_k], dim=2)
joint_v = torch.cat([txt_v, img_v], dim=2)
- joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v)
-
- joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype)
+ joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
txt_attn_output = joint_attn_out[:, :seq_txt, :]
img_attn_output = joint_attn_out[:, seq_txt:, :]
@@ -245,6 +278,7 @@ class QwenImageTransformerBlock(nn.Module):
text: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ enable_fp8_attention = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
@@ -260,6 +294,7 @@ class QwenImageTransformerBlock(nn.Module):
image=img_modulated,
text=txt_modulated,
image_rotary_emb=image_rotary_emb,
+ enable_fp8_attention=enable_fp8_attention,
)
image = image + img_gate * img_attn_out
diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
index deccd62..ad8e256 100644
--- a/diffsynth/pipelines/qwen_image.py
+++ b/diffsynth/pipelines/qwen_image.py
@@ -62,14 +62,12 @@ class QwenImagePipeline(BasePipeline):
return loss
- def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
+ def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
self.vram_management_enabled = True
- if num_persistent_param_in_dit is not None:
- vram_limit = None
- else:
- if vram_limit is None:
- vram_limit = self.get_vram()
- vram_limit = vram_limit - vram_buffer
+ if vram_limit is None:
+ vram_limit = self.get_vram()
+ vram_limit = vram_limit - vram_buffer
+
if self.text_encoder is not None:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
dtype = next(iter(self.text_encoder.parameters())).dtype
@@ -95,31 +93,54 @@ class QwenImagePipeline(BasePipeline):
from ..models.qwen_image_dit import RMSNorm
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
- enable_vram_management(
- self.dit,
- module_map = {
- RMSNorm: AutoWrappedModule,
- torch.nn.Linear: AutoWrappedLinear,
- },
- module_config = dict(
- offload_dtype=dtype,
- offload_device="cpu",
- onload_dtype=dtype,
- onload_device=device,
- computation_dtype=self.torch_dtype,
- computation_device=self.device,
- ),
- max_num_param=num_persistent_param_in_dit,
- overflow_module_config = dict(
- offload_dtype=dtype,
- offload_device="cpu",
- onload_dtype=dtype,
- onload_device="cpu",
- computation_dtype=self.torch_dtype,
- computation_device=self.device,
- ),
- vram_limit=vram_limit,
- )
+ if not enable_dit_fp8_computation:
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ RMSNorm: AutoWrappedModule,
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ else:
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=device,
+ computation_dtype=self.torch_dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
+ enable_vram_management(
+ self.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=dtype,
+ offload_device="cpu",
+ onload_dtype=dtype,
+ onload_device=device,
+ computation_dtype=dtype,
+ computation_device=self.device,
+ ),
+ vram_limit=vram_limit,
+ )
if self.vae is not None:
from ..models.qwen_image_vae import QwenImageRMS_norm
dtype = next(iter(self.vae.parameters())).dtype
@@ -190,6 +211,8 @@ class QwenImagePipeline(BasePipeline):
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
+ # FP8 Attention
+ enable_fp8_attention = False,
# Tile
tiled: bool = False,
tile_size: int = 128,
@@ -212,6 +235,7 @@ class QwenImagePipeline(BasePipeline):
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
+ "enable_fp8_attention": enable_fp8_attention,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
}
for unit in self.units:
@@ -331,6 +355,7 @@ def model_fn_qwen_image(
prompt_emb_mask=None,
height=None,
width=None,
+ enable_fp8_attention=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs
@@ -355,6 +380,7 @@ def model_fn_qwen_image(
text=text,
temb=conditioning,
image_rotary_emb=image_rotary_emb,
+ enable_fp8_attention=enable_fp8_attention,
)
image = dit.norm_out(image, conditioning)
diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py
index 0ebb054..c6f1ec6 100644
--- a/diffsynth/vram_management/layers.py
+++ b/diffsynth/vram_management/layers.py
@@ -110,8 +110,48 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
self.lora_A_weights = []
self.lora_B_weights = []
self.lora_merger = None
+ self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
+
+ def fp8_linear(
+ self,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ device = input.device
+ origin_dtype = input.dtype
+ origin_shape = input.shape
+ input = input.reshape(-1, origin_shape[-1])
+
+ x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
+ fp8_max = 448.0
+ # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
+ # To avoid overflow and ensure numerical compatibility during FP8 computation,
+ # we scale down the input by 2.0 in advance.
+ # This scaling will be compensated later during the final result scaling.
+ if self.computation_dtype == torch.float8_e4m3fnuz:
+ fp8_max = fp8_max / 2.0
+ scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
+ scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
+ input = input / (scale_a + 1e-8)
+ input = input.to(self.computation_dtype)
+ weight = weight.to(self.computation_dtype)
+ bias = bias.to(torch.bfloat16)
+
+ result = torch._scaled_mm(
+ input,
+ weight.T,
+ scale_a=scale_a,
+ scale_b=scale_b.T,
+ bias=bias,
+ out_dtype=origin_dtype,
+ )
+ new_shape = origin_shape[:-1] + result.shape[-1:]
+ result = result.reshape(new_shape)
+ return result
def forward(self, x, *args, **kwargs):
+ # VRAM management
if self.state == 2:
weight, bias = self.weight, self.bias
else:
@@ -123,8 +163,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
- out = torch.nn.functional.linear(x, weight, bias)
+ # Linear forward
+ if self.enable_fp8:
+ out = self.fp8_linear(x, weight, bias)
+ else:
+ out = torch.nn.functional.linear(x, weight, bias)
+
+ # LoRA
if len(self.lora_A_weights) == 0:
# No LoRA
return out
diff --git a/examples/qwen_image/README.md b/examples/qwen_image/README.md
index c9fd4ae..cff180c 100644
--- a/examples/qwen_image/README.md
+++ b/examples/qwen_image/README.md
@@ -172,7 +172,11 @@ After enabling VRAM management, the framework will automatically choose a memory
Inference Acceleration
-Inference acceleration for Qwen-Image is under development. Please stay tuned!
+* FP8 Quantization: Choose the appropriate quantization method based on your hardware and requirements.
+ * GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
+ * GPUs that support FP8 operations (e.g., H200, etc.): Please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention). Otherwise, FP8 acceleration will only apply to Linear layers.
+ * Faster inference but higher VRAM usage: Use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
+ * Slightly slower inference but lower VRAM usage: Use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
diff --git a/examples/qwen_image/README_zh.md b/examples/qwen_image/README_zh.md
index 0a311c1..9af0efd 100644
--- a/examples/qwen_image/README_zh.md
+++ b/examples/qwen_image/README_zh.md
@@ -172,7 +172,11 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
推理加速
-Qwen-Image 的推理加速技术正在开发中,敬请期待!
+* FP8 量化:根据您的硬件与需求,请选择合适的量化方式
+ * GPU 不支持 FP8 计算(例如 A100、4090 等):FP8 量化仅能降低显存占用,无法加速,代码:[./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
+ * GPU 支持 FP8 运算(例如 H200 等):请安装 [Flash Attention 3](https://github.com/Dao-AILab/flash-attention),否则 FP8 加速仅对 Linear 层生效
+ * 更快的速度,但更大的显存:请使用 [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
+ * 稍慢的速度,但更小的显存:请使用 [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
diff --git a/examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py b/examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py
new file mode 100644
index 0000000..2d40316
--- /dev/null
+++ b/examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py
@@ -0,0 +1,18 @@
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+import torch
+
+
+pipe = QwenImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
+)
+pipe.enable_vram_management(enable_dit_fp8_computation=True)
+prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
+image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
+image.save("image.jpg")
diff --git a/examples/qwen_image/accelerate/Qwen-Image-FP8.py b/examples/qwen_image/accelerate/Qwen-Image-FP8.py
new file mode 100644
index 0000000..9441407
--- /dev/null
+++ b/examples/qwen_image/accelerate/Qwen-Image-FP8.py
@@ -0,0 +1,51 @@
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+from diffsynth.models.qwen_image_dit import RMSNorm
+from diffsynth.vram_management.layers import enable_vram_management, AutoWrappedLinear, AutoWrappedModule
+import torch
+
+
+pipe = QwenImagePipeline.from_pretrained(
+ torch_dtype=torch.bfloat16,
+ device="cuda",
+ model_configs=[
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+ ],
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
+)
+
+enable_vram_management(
+ pipe.dit,
+ module_map = {
+ RMSNorm: AutoWrappedModule,
+ },
+ module_config = dict(
+ offload_dtype=torch.bfloat16,
+ offload_device="cuda",
+ onload_dtype=torch.bfloat16,
+ onload_device="cuda",
+ computation_dtype=torch.bfloat16,
+ computation_device="cuda",
+ ),
+ vram_limit=None,
+)
+enable_vram_management(
+ pipe.dit,
+ module_map = {
+ torch.nn.Linear: AutoWrappedLinear,
+ },
+ module_config = dict(
+ offload_dtype=torch.float8_e4m3fn,
+ offload_device="cuda",
+ onload_dtype=torch.float8_e4m3fn,
+ onload_device="cuda",
+ computation_dtype=torch.float8_e4m3fn,
+ computation_device="cuda",
+ ),
+ vram_limit=None,
+)
+
+prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
+image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
+image.save("image.jpg")