From 32cf5d32cebedf4992ebc742cd15fc0901b2f2b1 Mon Sep 17 00:00:00 2001
From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com>
Date: Thu, 7 Aug 2025 16:56:02 +0800
Subject: [PATCH] Qwen-Image FP8 (#761)

* support qwen-image-fp8

* refine README

* bugfix

* bugfix
---
 diffsynth/models/qwen_image_dit.py            | 43 ++++++++-
 diffsynth/pipelines/qwen_image.py             | 90 ++++++++++++-------
 diffsynth/vram_management/layers.py           | 48 +++++++++-
 examples/qwen_image/README.md                 |  7 +-
 examples/qwen_image/README_zh.md              |  7 +-
 .../accelerate/Qwen-Image-FP8-offload.py      | 18 ++++
 .../qwen_image/accelerate/Qwen-Image-FP8.py   | 51 +++++++++++
 7 files changed, 225 insertions(+), 39 deletions(-)
 create mode 100644 examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py
 create mode 100644 examples/qwen_image/accelerate/Qwen-Image-FP8.py

diff --git a/diffsynth/models/qwen_image_dit.py b/diffsynth/models/qwen_image_dit.py
index b8f92bb..919dd82 100644
--- a/diffsynth/models/qwen_image_dit.py
+++ b/diffsynth/models/qwen_image_dit.py
@@ -1,10 +1,44 @@
-import torch
+import torch, math
 import torch.nn as nn
 from typing import Tuple, Optional, Union, List
 from einops import rearrange
 from .sd3_dit import TimestepEmbeddings, RMSNorm
 from .flux_dit import AdaLayerNorm
 
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
+
+def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask=None, enable_fp8_attention: bool = False):
+    if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
+        if not enable_fp8_attention:
+            q = rearrange(q, "b n s d -> b s n d", n=num_heads)
+            k = rearrange(k, "b n s d -> b s n d", n=num_heads)
+            v = rearrange(v, "b n s d -> b s n d", n=num_heads)
+            x = flash_attn_interface.flash_attn_func(q, k, v)
+            if isinstance(x, tuple):
+                x = x[0]
+            x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
+        else:
+            origin_dtype = q.dtype
+            q_std, k_std, v_std = q.std(), k.std(), v.std()
+            q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
+            q = rearrange(q, "b n s d -> b s n d", n=num_heads)
+            k = rearrange(k, "b n s d -> b s n d", n=num_heads)
+            v = rearrange(v, "b n s d -> b s n d", n=num_heads)
+            x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
+            if isinstance(x, tuple):
+                x = x[0]
+            x = x.to(origin_dtype) * v_std
+            x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
+    else:
+        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
+        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+    return x
+
 
 class ApproximateGELU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
@@ -160,6 +194,7 @@ class QwenDoubleStreamAttention(nn.Module):
         text: torch.FloatTensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
+        enable_fp8_attention: bool = False,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
         txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -187,9 +222,7 @@ class QwenDoubleStreamAttention(nn.Module):
         joint_k = torch.cat([txt_k, img_k], dim=2)
         joint_v = torch.cat([txt_v, img_v], dim=2)
 
-        joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v, attn_mask=attention_mask)
-
-        joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype)
+        joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
 
         txt_attn_output = joint_attn_out[:, :seq_txt, :]
         img_attn_output = joint_attn_out[:, seq_txt:, :]
@@ -247,6 +280,7 @@ class QwenImageTransformerBlock(nn.Module):
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        enable_fp8_attention: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
 
@@ -263,6 +297,7 @@ class QwenImageTransformerBlock(nn.Module):
             text=txt_modulated,
             image_rotary_emb=image_rotary_emb,
             attention_mask=attention_mask,
+            enable_fp8_attention=enable_fp8_attention,
         )
 
         image = image + img_gate * img_attn_out
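Editor's note on `qwen_image_flash_attention`: dividing `q`, `k`, `v` by their standard deviations keeps values inside float8_e4m3fn's narrow dynamic range, and the two compensations (`softmax_scale=q_std * k_std / math.sqrt(d)` and the final `* v_std`) make the rescaling mathematically a no-op. The minimal sketch below verifies that identity in float32, so it runs without FP8 hardware or Flash Attention 3; shapes and tolerance are arbitrary.

```python
import math
import torch

torch.manual_seed(0)
b, h, s, d = 1, 4, 16, 32
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

# Reference: standard scaled dot-product attention.
reference = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# FP8-style normalization (the cast to float8_e4m3fn is omitted here).
q_std, k_std, v_std = q.std(), k.std(), v.std()
qs, ks, vs = q / q_std, k / k_std, v / v_std

# softmax_scale = q_std * k_std / sqrt(d) restores the original logits:
# (q/q_std) @ (k/k_std)^T * q_std * k_std / sqrt(d) == q @ k^T / sqrt(d)
scale = (q_std * k_std / math.sqrt(d)).item()
logits = torch.einsum("bhqd,bhkd->bhqk", qs, ks) * scale
out = torch.softmax(logits, dim=-1) @ vs
out = out * v_std  # undo the value normalization, as the patch does

assert torch.allclose(out, reference, atol=1e-4)
```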
diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
index 0611a7d..3e952c0 100644
--- a/diffsynth/pipelines/qwen_image.py
+++ b/diffsynth/pipelines/qwen_image.py
@@ -63,14 +63,12 @@ class QwenImagePipeline(BasePipeline):
 
         return loss
 
-    def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
+    def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
         self.vram_management_enabled = True
-        if num_persistent_param_in_dit is not None:
-            vram_limit = None
-        else:
-            if vram_limit is None:
-                vram_limit = self.get_vram()
-            vram_limit = vram_limit - vram_buffer
+        if vram_limit is None:
+            vram_limit = self.get_vram()
+        vram_limit = vram_limit - vram_buffer
         if self.text_encoder is not None:
             from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
             dtype = next(iter(self.text_encoder.parameters())).dtype
@@ -96,31 +94,54 @@ class QwenImagePipeline(BasePipeline):
             from ..models.qwen_image_dit import RMSNorm
             dtype = next(iter(self.dit.parameters())).dtype
             device = "cpu" if vram_limit is not None else self.device
-            enable_vram_management(
-                self.dit,
-                module_map = {
-                    RMSNorm: AutoWrappedModule,
-                    torch.nn.Linear: AutoWrappedLinear,
-                },
-                module_config = dict(
-                    offload_dtype=dtype,
-                    offload_device="cpu",
-                    onload_dtype=dtype,
-                    onload_device=device,
-                    computation_dtype=self.torch_dtype,
-                    computation_device=self.device,
-                ),
-                max_num_param=num_persistent_param_in_dit,
-                overflow_module_config = dict(
-                    offload_dtype=dtype,
-                    offload_device="cpu",
-                    onload_dtype=dtype,
-                    onload_device="cpu",
-                    computation_dtype=self.torch_dtype,
-                    computation_device=self.device,
-                ),
-                vram_limit=vram_limit,
-            )
+            if not enable_dit_fp8_computation:
+                enable_vram_management(
+                    self.dit,
+                    module_map = {
+                        RMSNorm: AutoWrappedModule,
+                        torch.nn.Linear: AutoWrappedLinear,
+                    },
+                    module_config = dict(
+                        offload_dtype=dtype,
+                        offload_device="cpu",
+                        onload_dtype=dtype,
+                        onload_device=device,
+                        computation_dtype=self.torch_dtype,
+                        computation_device=self.device,
+                    ),
+                    vram_limit=vram_limit,
+                )
+            else:
+                enable_vram_management(
+                    self.dit,
+                    module_map = {
+                        RMSNorm: AutoWrappedModule,
+                    },
+                    module_config = dict(
+                        offload_dtype=dtype,
+                        offload_device="cpu",
+                        onload_dtype=dtype,
+                        onload_device=device,
+                        computation_dtype=self.torch_dtype,
+                        computation_device=self.device,
+                    ),
+                    vram_limit=vram_limit,
+                )
+                enable_vram_management(
+                    self.dit,
+                    module_map = {
+                        torch.nn.Linear: AutoWrappedLinear,
+                    },
+                    module_config = dict(
+                        offload_dtype=dtype,
+                        offload_device="cpu",
+                        onload_dtype=dtype,
+                        onload_device=device,
+                        computation_dtype=dtype,
+                        computation_device=self.device,
+                    ),
+                    vram_limit=vram_limit,
+                )
         if self.vae is not None:
             from ..models.qwen_image_vae import QwenImageRMS_norm
             dtype = next(iter(self.vae.parameters())).dtype
@@ -195,6 +216,8 @@ class QwenImagePipeline(BasePipeline):
         eligen_entity_prompts: list[str] = None,
         eligen_entity_masks: list[Image.Image] = None,
         eligen_enable_on_negative: bool = False,
+        # FP8
+        enable_fp8_attention: bool = False,
         # Tile
         tiled: bool = False,
         tile_size: int = 128,
@@ -217,6 +240,7 @@ class QwenImagePipeline(BasePipeline):
             "input_image": input_image, "denoising_strength": denoising_strength,
             "height": height, "width": width, "seed": seed, "rand_device": rand_device,
+            "enable_fp8_attention": enable_fp8_attention,
             "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
             "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
         }
@@ -418,6 +442,7 @@ def model_fn_qwen_image(
     entity_prompt_emb=None,
     entity_prompt_emb_mask=None,
     entity_masks=None,
+    enable_fp8_attention=False,
     use_gradient_checkpointing=False,
     use_gradient_checkpointing_offload=False,
     **kwargs
@@ -451,6 +476,7 @@ def model_fn_qwen_image(
                 temb=conditioning,
                 image_rotary_emb=image_rotary_emb,
                 attention_mask=attention_mask,
+                enable_fp8_attention=enable_fp8_attention,
             )
 
     image = dit.norm_out(image, conditioning)
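Editor's note: in the `else` branch above, the second `enable_vram_management` pass deliberately sets `computation_dtype=dtype`, i.e. the dtype the DiT weights were loaded in. When that is `torch.float8_e4m3fn`, the `AutoWrappedLinear` change below switches to its `torch._scaled_mm` path based on exactly this value. A minimal illustration of the dispatch condition (the helper name is hypothetical, not library code):

```python
import torch

# Same membership test AutoWrappedLinear.__init__ performs in this patch;
# uses_fp8_matmul is a stand-in for the wrapped layer's enable_fp8 attribute.
def uses_fp8_matmul(computation_dtype: torch.dtype) -> bool:
    return computation_dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz)

assert uses_fp8_matmul(torch.float8_e4m3fn)   # FP8-stored DiT -> fp8_linear
assert not uses_fp8_matmul(torch.bfloat16)    # default path -> F.linear
```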
diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py
index 0ebb054..c6f1ec6 100644
--- a/diffsynth/vram_management/layers.py
+++ b/diffsynth/vram_management/layers.py
@@ -110,8 +110,48 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         self.lora_A_weights = []
         self.lora_B_weights = []
         self.lora_merger = None
+        self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
+
+    def fp8_linear(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        device = input.device
+        origin_dtype = input.dtype
+        origin_shape = input.shape
+        input = input.reshape(-1, origin_shape[-1])
+
+        x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
+        fp8_max = 448.0
+        # float8_e4m3fnuz has a smaller maximum representable value than
+        # e4m3fn (240 vs. 448), so we conservatively halve the bound before
+        # computing the scale. The extra downscaling applied to the input is
+        # compensated when _scaled_mm folds scale_a back into the result.
+        if self.computation_dtype == torch.float8_e4m3fnuz:
+            fp8_max = fp8_max / 2.0
+        scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
+        scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
+        input = input / (scale_a + 1e-8)
+        input = input.to(self.computation_dtype)
+        weight = weight.to(self.computation_dtype)
+        bias = bias.to(torch.bfloat16) if bias is not None else None
+
+        result = torch._scaled_mm(
+            input,
+            weight.T,
+            scale_a=scale_a,
+            scale_b=scale_b.T,
+            bias=bias,
+            out_dtype=origin_dtype,
+        )
+        new_shape = origin_shape[:-1] + result.shape[-1:]
+        result = result.reshape(new_shape)
+        return result
 
     def forward(self, x, *args, **kwargs):
+        # VRAM management
         if self.state == 2:
             weight, bias = self.weight, self.bias
         else:
@@ -123,8 +163,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
             else:
                 weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
                 bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
-        out = torch.nn.functional.linear(x, weight, bias)
 
+        # Linear forward
+        if self.enable_fp8:
+            out = self.fp8_linear(x, weight, bias)
+        else:
+            out = torch.nn.functional.linear(x, weight, bias)
+
+        # LoRA
         if len(self.lora_A_weights) == 0:
             # No LoRA
             return out
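Editor's note on `fp8_linear`: it applies dynamic per-row activation scaling; each row of the flattened input is scaled so its maximum magnitude fits the FP8 range, and `torch._scaled_mm` folds the scales back into the output. Below is a standalone sketch of the same recipe, assuming an FP8-capable GPU (compute capability 8.9+) and a recent PyTorch; `torch._scaled_mm` is a private API, so the signature may differ across versions.

```python
import torch

assert torch.cuda.is_available()
device = "cuda"
x = torch.randn(16, 64, dtype=torch.bfloat16, device=device)    # activations
w = torch.randn(128, 64, dtype=torch.bfloat16, device=device)   # weight (out_features, in_features)

# Per-row dynamic scale; clamp so rows already inside FP8 range keep scale 1.
x_max = x.abs().amax(dim=-1, keepdim=True)
scale_a = torch.clamp(x_max / 448.0, min=1.0).float()  # 448 = e4m3fn max
scale_b = torch.ones(1, w.shape[0], device=device)     # weight left unscaled

x_fp8 = (x / scale_a).to(torch.float8_e4m3fn)
w_fp8 = w.to(torch.float8_e4m3fn)

# w_fp8.T is a column-major view, the layout _scaled_mm expects for mat2.
out = torch._scaled_mm(x_fp8, w_fp8.T, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
ref = torch.nn.functional.linear(x, w)
print((out - ref).abs().max())  # FP8 quantization error, not exact equality
```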
diff --git a/examples/qwen_image/README.md b/examples/qwen_image/README.md
index 6eaa533..aba8120 100644
--- a/examples/qwen_image/README.md
+++ b/examples/qwen_image/README.md
@@ -164,6 +164,7 @@ After enabling VRAM management, the framework will automatically choose a memory
 * `vram_limit`: VRAM usage limit in GB. By default, it uses all free VRAM on the device. Note that this is not a strict limit. If the set limit is too low but actual free VRAM is enough, the model will run with minimal VRAM use. Set it to 0 for the smallest possible VRAM use.
 * `vram_buffer`: VRAM buffer size in GB. Default is 0.5GB. A buffer is needed because large network layers may use more VRAM than expected during loading. The best value is the VRAM size of the largest model layer.
 * `num_persistent_param_in_dit`: Number of parameters to keep in VRAM in the DiT model. Default is no limit. This option will be removed in the future. Do not rely on it.
+* `enable_dit_fp8_computation`: Whether to enable FP8 computation in the DiT model. Only applicable to GPUs with native FP8 support (e.g., H200). Disabled by default.
@@ -172,7 +173,11 @@ After enabling VRAM management, the framework will automatically choose a memory
 
 Inference Acceleration
 
-Inference acceleration for Qwen-Image is under development. Please stay tuned!
+* FP8 Quantization: Choose the appropriate quantization method based on your hardware and requirements.
+  * GPUs without FP8 compute support (e.g., A100, 4090): FP8 quantization only reduces VRAM usage and does not speed up inference. Code: [./model_inference_low_vram/Qwen-Image.py](./model_inference_low_vram/Qwen-Image.py)
+  * GPUs with FP8 compute support (e.g., H200): Please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention); otherwise, FP8 acceleration will only apply to Linear layers.
+    * Faster inference but higher VRAM usage: use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
+    * Slightly slower inference but lower VRAM usage: use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
diff --git a/examples/qwen_image/README_zh.md b/examples/qwen_image/README_zh.md
index 305e1ab..1259dc3 100644
--- a/examples/qwen_image/README_zh.md
+++ b/examples/qwen_image/README_zh.md
@@ -164,6 +164,7 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
 * `vram_limit`: 显存占用量限制(GB),默认占用设备上的剩余显存。注意这不是一个绝对限制,当设置的显存不足以支持模型进行推理,但实际可用显存足够时,将会以最小化显存占用的形式进行推理。将其设置为0时,将会实现理论最小显存占用。
 * `vram_buffer`: 显存缓冲区大小(GB),默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。
 * `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。
+* `enable_dit_fp8_computation`: 是否启用 DiT 模型中的 FP8 计算,仅适用于支持 FP8 运算的 GPU(例如 H200 等),默认不启用。
@@ -172,7 +173,11 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
 
 推理加速
 
-Qwen-Image 的推理加速技术正在开发中,敬请期待!
+* FP8 量化:根据您的硬件与需求,请选择合适的量化方式
+  * GPU 不支持 FP8 计算(例如 A100、4090 等):FP8 量化仅能降低显存占用,无法加速,代码:[./model_inference_low_vram/Qwen-Image.py](./model_inference_low_vram/Qwen-Image.py)
+  * GPU 支持 FP8 运算(例如 H200 等):请安装 [Flash Attention 3](https://github.com/Dao-AILab/flash-attention),否则 FP8 加速仅对 Linear 层生效
+    * 更快的速度,但更大的显存:请使用 [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
+    * 稍慢的速度,但更小的显存:请使用 [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
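Editor's note: for the storage-only FP8 path referenced above (GPUs without FP8 compute), the relevant knob is `offload_dtype` on `ModelConfig`, with `enable_vram_management()` left at its BF16 compute default. A minimal sketch modeled on the example scripts added below; the prompt is a placeholder.

```python
import torch
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig

# FP8 as a storage-only optimization: weights are kept in float8_e4m3fn and
# upcast to BF16 for compute, so VRAM drops on any GPU but speed is unchanged.
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image",
                    origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors",
                    offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image",
                    origin_file_pattern="text_encoder/model*.safetensors",
                    offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image",
                    origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management()  # default BF16 compute; FP8 is storage only
image = pipe("a serene underwater portrait, flowing blue dress, soft light",
             seed=0, num_inference_steps=40)
image.save("image.jpg")
```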
diff --git a/examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py b/examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py
new file mode 100644
index 0000000..2d40316
--- /dev/null
+++ b/examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py
@@ -0,0 +1,18 @@
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+import torch
+
+
+pipe = QwenImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+    ],
+    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
+)
+pipe.enable_vram_management(enable_dit_fp8_computation=True)
+prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
+image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
+image.save("image.jpg")
diff --git a/examples/qwen_image/accelerate/Qwen-Image-FP8.py b/examples/qwen_image/accelerate/Qwen-Image-FP8.py
new file mode 100644
index 0000000..9441407
--- /dev/null
+++ b/examples/qwen_image/accelerate/Qwen-Image-FP8.py
@@ -0,0 +1,51 @@
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+from diffsynth.models.qwen_image_dit import RMSNorm
+from diffsynth.vram_management.layers import enable_vram_management, AutoWrappedLinear, AutoWrappedModule
+import torch
+
+
+pipe = QwenImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
+)
+
+enable_vram_management(
+    pipe.dit,
+    module_map = {
+        RMSNorm: AutoWrappedModule,
+    },
+    module_config = dict(
+        offload_dtype=torch.bfloat16,
+        offload_device="cuda",
+        onload_dtype=torch.bfloat16,
+        onload_device="cuda",
+        computation_dtype=torch.bfloat16,
+        computation_device="cuda",
+    ),
+    vram_limit=None,
+)
+enable_vram_management(
+    pipe.dit,
+    module_map = {
+        torch.nn.Linear: AutoWrappedLinear,
+    },
+    module_config = dict(
+        offload_dtype=torch.float8_e4m3fn,
+        offload_device="cuda",
+        onload_dtype=torch.float8_e4m3fn,
+        onload_device="cuda",
+        computation_dtype=torch.float8_e4m3fn,
+        computation_device="cuda",
+    ),
+    vram_limit=None,
+)
+
+prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
+image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
+image.save("image.jpg")