diff --git a/diffsynth/models/qwen_image_dit.py b/diffsynth/models/qwen_image_dit.py index 15f8747..423a0a5 100644 --- a/diffsynth/models/qwen_image_dit.py +++ b/diffsynth/models/qwen_image_dit.py @@ -1,10 +1,44 @@ -import torch +import torch, math import torch.nn as nn from typing import Tuple, Optional, Union, List from einops import rearrange from .sd3_dit import TimestepEmbeddings, RMSNorm from .flux_dit import AdaLayerNorm +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + + +def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, enable_fp8_attention: bool = False): + if FLASH_ATTN_3_AVAILABLE: + if not enable_fp8_attention: + q = rearrange(q, "b n s d -> b s n d", n=num_heads) + k = rearrange(k, "b n s d -> b s n d", n=num_heads) + v = rearrange(v, "b n s d -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v) + if isinstance(x, tuple): + x = x[0] + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + else: + origin_dtype = q.dtype + q_std, k_std, v_std = q.std(), k.std(), v.std() + q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn) + q = rearrange(q, "b n s d -> b s n d", n=num_heads) + k = rearrange(k, "b n s d -> b s n d", n=num_heads) + v = rearrange(v, "b n s d -> b s n d", n=num_heads) + x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1))) + if isinstance(x, tuple): + x = x[0] + x = x.to(origin_dtype) * v_std + x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) + return x + class ApproximateGELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, bias: bool = True): @@ -158,7 +192,8 @@ class QwenDoubleStreamAttention(nn.Module): self, image: 
torch.FloatTensor, text: torch.FloatTensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + enable_fp8_attention: bool = False, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image) txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text) @@ -186,9 +221,7 @@ class QwenDoubleStreamAttention(nn.Module): joint_k = torch.cat([txt_k, img_k], dim=2) joint_v = torch.cat([txt_v, img_v], dim=2) - joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v) - - joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype) + joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype) txt_attn_output = joint_attn_out[:, :seq_txt, :] img_attn_output = joint_attn_out[:, seq_txt:, :] @@ -245,6 +278,7 @@ class QwenImageTransformerBlock(nn.Module): text: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + enable_fp8_attention = False, ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each @@ -260,6 +294,7 @@ class QwenImageTransformerBlock(nn.Module): image=img_modulated, text=txt_modulated, image_rotary_emb=image_rotary_emb, + enable_fp8_attention=enable_fp8_attention, ) image = image + img_gate * img_attn_out diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index deccd62..ad8e256 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -62,14 +62,12 @@ class QwenImagePipeline(BasePipeline): return loss - def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): + def enable_vram_management(self, 
num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False): self.vram_management_enabled = True - if num_persistent_param_in_dit is not None: - vram_limit = None - else: - if vram_limit is None: - vram_limit = self.get_vram() - vram_limit = vram_limit - vram_buffer + if vram_limit is None: + vram_limit = self.get_vram() + vram_limit = vram_limit - vram_buffer + if self.text_encoder is not None: from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm dtype = next(iter(self.text_encoder.parameters())).dtype @@ -95,31 +93,54 @@ class QwenImagePipeline(BasePipeline): from ..models.qwen_image_dit import RMSNorm dtype = next(iter(self.dit.parameters())).dtype device = "cpu" if vram_limit is not None else self.device - enable_vram_management( - self.dit, - module_map = { - RMSNorm: AutoWrappedModule, - torch.nn.Linear: AutoWrappedLinear, - }, - module_config = dict( - offload_dtype=dtype, - offload_device="cpu", - onload_dtype=dtype, - onload_device=device, - computation_dtype=self.torch_dtype, - computation_device=self.device, - ), - max_num_param=num_persistent_param_in_dit, - overflow_module_config = dict( - offload_dtype=dtype, - offload_device="cpu", - onload_dtype=dtype, - onload_device="cpu", - computation_dtype=self.torch_dtype, - computation_device=self.device, - ), - vram_limit=vram_limit, - ) + if not enable_dit_fp8_computation: + enable_vram_management( + self.dit, + module_map = { + RMSNorm: AutoWrappedModule, + torch.nn.Linear: AutoWrappedLinear, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) + else: + enable_vram_management( + self.dit, + module_map = { + RMSNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + 
onload_device=device, + computation_dtype=self.torch_dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) + enable_vram_management( + self.dit, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device=device, + computation_dtype=dtype, + computation_device=self.device, + ), + vram_limit=vram_limit, + ) if self.vae is not None: from ..models.qwen_image_vae import QwenImageRMS_norm dtype = next(iter(self.vae.parameters())).dtype @@ -190,6 +211,8 @@ class QwenImagePipeline(BasePipeline): rand_device: str = "cpu", # Steps num_inference_steps: int = 30, + # FP8 Attention + enable_fp8_attention = False, # Tile tiled: bool = False, tile_size: int = 128, @@ -212,6 +235,7 @@ class QwenImagePipeline(BasePipeline): "input_image": input_image, "denoising_strength": denoising_strength, "height": height, "width": width, "seed": seed, "rand_device": rand_device, + "enable_fp8_attention": enable_fp8_attention, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, } for unit in self.units: @@ -331,6 +355,7 @@ def model_fn_qwen_image( prompt_emb_mask=None, height=None, width=None, + enable_fp8_attention=False, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, **kwargs @@ -355,6 +380,7 @@ def model_fn_qwen_image( text=text, temb=conditioning, image_rotary_emb=image_rotary_emb, + enable_fp8_attention=enable_fp8_attention, ) image = dit.norm_out(image, conditioning) diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py index 0ebb054..c6f1ec6 100644 --- a/diffsynth/vram_management/layers.py +++ b/diffsynth/vram_management/layers.py @@ -110,8 +110,48 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): self.lora_A_weights = [] self.lora_B_weights = [] self.lora_merger = None + self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz] + + def 
fp8_linear( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + device = input.device + origin_dtype = input.dtype + origin_shape = input.shape + input = input.reshape(-1, origin_shape[-1]) + + x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values + fp8_max = 448.0 + # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn. + # To avoid overflow and ensure numerical compatibility during FP8 computation, + # we scale down the input by 2.0 in advance. + # This scaling will be compensated later during the final result scaling. + if self.computation_dtype == torch.float8_e4m3fnuz: + fp8_max = fp8_max / 2.0 + scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device) + scale_b = torch.ones((weight.shape[0], 1)).to(device=device) + input = input / (scale_a + 1e-8) + input = input.to(self.computation_dtype) + weight = weight.to(self.computation_dtype) + bias = bias.to(torch.bfloat16) + + result = torch._scaled_mm( + input, + weight.T, + scale_a=scale_a, + scale_b=scale_b.T, + bias=bias, + out_dtype=origin_dtype, + ) + new_shape = origin_shape[:-1] + result.shape[-1:] + result = result.reshape(new_shape) + return result def forward(self, x, *args, **kwargs): + # VRAM management if self.state == 2: weight, bias = self.weight, self.bias else: @@ -123,8 +163,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): else: weight = cast_to(self.weight, self.computation_dtype, self.computation_device) bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) - out = torch.nn.functional.linear(x, weight, bias) + # Linear forward + if self.enable_fp8: + out = self.fp8_linear(x, weight, bias) + else: + out = torch.nn.functional.linear(x, weight, bias) + + # LoRA if len(self.lora_A_weights) == 0: # No LoRA return out diff --git a/examples/qwen_image/README.md b/examples/qwen_image/README.md index c9fd4ae..cff180c 
100644 --- a/examples/qwen_image/README.md +++ b/examples/qwen_image/README.md @@ -172,7 +172,11 @@ After enabling VRAM management, the framework will automatically choose a memory Inference Acceleration -Inference acceleration for Qwen-Image is under development. Please stay tuned! +* FP8 Quantization: Choose the appropriate quantization method based on your hardware and requirements. + * GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_low_vram/Qwen-Image.py](./model_inference_low_vram/Qwen-Image.py) + * GPUs that support FP8 operations (e.g., H200, etc.): Please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention). Otherwise, FP8 acceleration will only apply to Linear layers. + * Faster inference but higher VRAM usage: Use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py) + * Slightly slower inference but lower VRAM usage: Use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py) diff --git a/examples/qwen_image/README_zh.md b/examples/qwen_image/README_zh.md index 0a311c1..9af0efd 100644 --- a/examples/qwen_image/README_zh.md +++ b/examples/qwen_image/README_zh.md @@ -172,7 +172,11 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在 推理加速 -Qwen-Image 的推理加速技术正在开发中,敬请期待! 
+* FP8 量化:根据您的硬件与需求,请选择合适的量化方式 + * GPU 不支持 FP8 计算(例如 A100、4090 等):FP8 量化仅能降低显存占用,无法加速,代码:[./model_inference_low_vram/Qwen-Image.py](./model_inference_low_vram/Qwen-Image.py) + * GPU 支持 FP8 运算(例如 H200 等):请安装 [Flash Attention 3](https://github.com/Dao-AILab/flash-attention),否则 FP8 加速仅对 Linear 层生效 + * 更快的速度,但更大的显存:请使用 [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py) + * 稍慢的速度,但更小的显存:请使用 [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py) diff --git a/examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py b/examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py new file mode 100644 index 0000000..2d40316 --- /dev/null +++ b/examples/qwen_image/accelerate/Qwen-Image-FP8-offload.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.enable_vram_management(enable_dit_fp8_computation=True) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True) +image.save("image.jpg") diff --git a/examples/qwen_image/accelerate/Qwen-Image-FP8.py b/examples/qwen_image/accelerate/Qwen-Image-FP8.py new file mode 100644 index 0000000..9441407 --- /dev/null +++ b/examples/qwen_image/accelerate/Qwen-Image-FP8.py @@ 
-0,0 +1,51 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.models.qwen_image_dit import RMSNorm +from diffsynth.vram_management.layers import enable_vram_management, AutoWrappedLinear, AutoWrappedModule +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) + +enable_vram_management( + pipe.dit, + module_map = { + RMSNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=torch.bfloat16, + offload_device="cuda", + onload_dtype=torch.bfloat16, + onload_device="cuda", + computation_dtype=torch.bfloat16, + computation_device="cuda", + ), + vram_limit=None, +) +enable_vram_management( + pipe.dit, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + }, + module_config = dict( + offload_dtype=torch.float8_e4m3fn, + offload_device="cuda", + onload_dtype=torch.float8_e4m3fn, + onload_device="cuda", + computation_dtype=torch.float8_e4m3fn, + computation_device="cuda", + ), + vram_limit=None, +) + +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True) +image.save("image.jpg")