support qwen-image-fp8

This commit is contained in:
Artiprocher
2025-08-07 16:20:50 +08:00
parent d20680baae
commit a0eec8c673
7 changed files with 224 additions and 40 deletions

View File

@@ -1,10 +1,44 @@
import torch import torch, math
import torch.nn as nn import torch.nn as nn
from typing import Tuple, Optional, Union, List from typing import Tuple, Optional, Union, List
from einops import rearrange from einops import rearrange
from .sd3_dit import TimestepEmbeddings, RMSNorm from .sd3_dit import TimestepEmbeddings, RMSNorm
from .flux_dit import AdaLayerNorm from .flux_dit import AdaLayerNorm
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, enable_fp8_attention: bool = False):
if FLASH_ATTN_3_AVAILABLE:
if not enable_fp8_attention:
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x, tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
origin_dtype = q.dtype
q_std, k_std, v_std = q.std(), k.std(), v.std()
q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
if isinstance(x, tuple):
x = x[0]
x = x.to(origin_dtype) * v_std
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
return x
class ApproximateGELU(nn.Module): class ApproximateGELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True): def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
@@ -158,7 +192,8 @@ class QwenDoubleStreamAttention(nn.Module):
self, self,
image: torch.FloatTensor, image: torch.FloatTensor,
text: torch.FloatTensor, text: torch.FloatTensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
enable_fp8_attention: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]: ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image) img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text) txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -186,9 +221,7 @@ class QwenDoubleStreamAttention(nn.Module):
joint_k = torch.cat([txt_k, img_k], dim=2) joint_k = torch.cat([txt_k, img_k], dim=2)
joint_v = torch.cat([txt_v, img_v], dim=2) joint_v = torch.cat([txt_v, img_v], dim=2)
joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v) joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype)
txt_attn_output = joint_attn_out[:, :seq_txt, :] txt_attn_output = joint_attn_out[:, :seq_txt, :]
img_attn_output = joint_attn_out[:, seq_txt:, :] img_attn_output = joint_attn_out[:, seq_txt:, :]
@@ -245,6 +278,7 @@ class QwenImageTransformerBlock(nn.Module):
text: torch.Tensor, text: torch.Tensor,
temb: torch.Tensor, temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
enable_fp8_attention = False,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
@@ -260,6 +294,7 @@ class QwenImageTransformerBlock(nn.Module):
image=img_modulated, image=img_modulated,
text=txt_modulated, text=txt_modulated,
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
enable_fp8_attention=enable_fp8_attention,
) )
image = image + img_gate * img_attn_out image = image + img_gate * img_attn_out

View File

@@ -62,14 +62,12 @@ class QwenImagePipeline(BasePipeline):
return loss return loss
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
self.vram_management_enabled = True self.vram_management_enabled = True
if num_persistent_param_in_dit is not None: if vram_limit is None:
vram_limit = None vram_limit = self.get_vram()
else: vram_limit = vram_limit - vram_buffer
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None: if self.text_encoder is not None:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
dtype = next(iter(self.text_encoder.parameters())).dtype dtype = next(iter(self.text_encoder.parameters())).dtype
@@ -95,31 +93,54 @@ class QwenImagePipeline(BasePipeline):
from ..models.qwen_image_dit import RMSNorm from ..models.qwen_image_dit import RMSNorm
dtype = next(iter(self.dit.parameters())).dtype dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device device = "cpu" if vram_limit is not None else self.device
enable_vram_management( if not enable_dit_fp8_computation:
self.dit, enable_vram_management(
module_map = { self.dit,
RMSNorm: AutoWrappedModule, module_map = {
torch.nn.Linear: AutoWrappedLinear, RMSNorm: AutoWrappedModule,
}, torch.nn.Linear: AutoWrappedLinear,
module_config = dict( },
offload_dtype=dtype, module_config = dict(
offload_device="cpu", offload_dtype=dtype,
onload_dtype=dtype, offload_device="cpu",
onload_device=device, onload_dtype=dtype,
computation_dtype=self.torch_dtype, onload_device=device,
computation_device=self.device, computation_dtype=self.torch_dtype,
), computation_device=self.device,
max_num_param=num_persistent_param_in_dit, ),
overflow_module_config = dict( vram_limit=vram_limit,
offload_dtype=dtype, )
offload_device="cpu", else:
onload_dtype=dtype, enable_vram_management(
onload_device="cpu", self.dit,
computation_dtype=self.torch_dtype, module_map = {
computation_device=self.device, RMSNorm: AutoWrappedModule,
), },
vram_limit=vram_limit, module_config = dict(
) offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None: if self.vae is not None:
from ..models.qwen_image_vae import QwenImageRMS_norm from ..models.qwen_image_vae import QwenImageRMS_norm
dtype = next(iter(self.vae.parameters())).dtype dtype = next(iter(self.vae.parameters())).dtype
@@ -190,6 +211,8 @@ class QwenImagePipeline(BasePipeline):
rand_device: str = "cpu", rand_device: str = "cpu",
# Steps # Steps
num_inference_steps: int = 30, num_inference_steps: int = 30,
# FP8 Attention
enable_fp8_attention = False,
# Tile # Tile
tiled: bool = False, tiled: bool = False,
tile_size: int = 128, tile_size: int = 128,
@@ -212,6 +235,7 @@ class QwenImagePipeline(BasePipeline):
"input_image": input_image, "denoising_strength": denoising_strength, "input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width, "height": height, "width": width,
"seed": seed, "rand_device": rand_device, "seed": seed, "rand_device": rand_device,
"enable_fp8_attention": enable_fp8_attention,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
} }
for unit in self.units: for unit in self.units:
@@ -331,6 +355,7 @@ def model_fn_qwen_image(
prompt_emb_mask=None, prompt_emb_mask=None,
height=None, height=None,
width=None, width=None,
enable_fp8_attention=False,
use_gradient_checkpointing=False, use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False, use_gradient_checkpointing_offload=False,
**kwargs **kwargs
@@ -355,6 +380,7 @@ def model_fn_qwen_image(
text=text, text=text,
temb=conditioning, temb=conditioning,
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
enable_fp8_attention=enable_fp8_attention,
) )
image = dit.norm_out(image, conditioning) image = dit.norm_out(image, conditioning)

View File

@@ -110,8 +110,48 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
self.lora_A_weights = [] self.lora_A_weights = []
self.lora_B_weights = [] self.lora_B_weights = []
self.lora_merger = None self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
def fp8_linear(
self,
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
device = input.device
origin_dtype = input.dtype
origin_shape = input.shape
input = input.reshape(-1, origin_shape[-1])
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
fp8_max = 448.0
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
# To avoid overflow and ensure numerical compatibility during FP8 computation,
# we scale down the input by 2.0 in advance.
# This scaling will be compensated later during the final result scaling.
if self.computation_dtype == torch.float8_e4m3fnuz:
fp8_max = fp8_max / 2.0
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
input = input / (scale_a + 1e-8)
input = input.to(self.computation_dtype)
weight = weight.to(self.computation_dtype)
bias = bias.to(torch.bfloat16)
result = torch._scaled_mm(
input,
weight.T,
scale_a=scale_a,
scale_b=scale_b.T,
bias=bias,
out_dtype=origin_dtype,
)
new_shape = origin_shape[:-1] + result.shape[-1:]
result = result.reshape(new_shape)
return result
def forward(self, x, *args, **kwargs): def forward(self, x, *args, **kwargs):
# VRAM management
if self.state == 2: if self.state == 2:
weight, bias = self.weight, self.bias weight, bias = self.weight, self.bias
else: else:
@@ -123,8 +163,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
else: else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device) weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
out = torch.nn.functional.linear(x, weight, bias)
# Linear forward
if self.enable_fp8:
out = self.fp8_linear(x, weight, bias)
else:
out = torch.nn.functional.linear(x, weight, bias)
# LoRA
if len(self.lora_A_weights) == 0: if len(self.lora_A_weights) == 0:
# No LoRA # No LoRA
return out return out

View File

@@ -172,7 +172,11 @@ After enabling VRAM management, the framework will automatically choose a memory
<summary>Inference Acceleration</summary> <summary>Inference Acceleration</summary>
Inference acceleration for Qwen-Image is under development. Please stay tuned! * FP8 Quantization: Choose the appropriate quantization method based on your hardware and requirements.
* GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
* GPUs that support FP8 operations (e.g., H200, etc.): Please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention). Otherwise, FP8 acceleration will only apply to Linear layers.
* Faster inference but higher VRAM usage: Use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
* Slightly slower inference but lower VRAM usage: Use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
</details> </details>

View File

@@ -172,7 +172,11 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
<summary>推理加速</summary> <summary>推理加速</summary>
Qwen-Image 的推理加速技术正在开发中,敬请期待! * FP8 量化:根据您的硬件与需求,请选择合适的量化方式
* GPU 不支持 FP8 计算(例如 A100、4090 等FP8 量化仅能降低显存占用,无法加速,代码:[./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
* GPU 支持 FP8 运算(例如 H200 等):请安装 [Flash Attention 3](https://github.com/Dao-AILab/flash-attention),否则 FP8 加速仅对 Linear 层生效
* 更快的速度,但更大的显存:请使用 [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
* 稍慢的速度,但更小的显存:请使用 [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
</details> </details>

View File

@@ -0,0 +1,18 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management(enable_dit_fp8_computation=True)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
image.save("image.jpg")

View File

@@ -0,0 +1,51 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.models.qwen_image_dit import RMSNorm
from diffsynth.vram_management.layers import enable_vram_management, AutoWrappedLinear, AutoWrappedModule
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
enable_vram_management(
pipe.dit,
module_map = {
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=torch.bfloat16,
offload_device="cuda",
onload_dtype=torch.bfloat16,
onload_device="cuda",
computation_dtype=torch.bfloat16,
computation_device="cuda",
),
vram_limit=None,
)
enable_vram_management(
pipe.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=torch.float8_e4m3fn,
offload_device="cuda",
onload_dtype=torch.float8_e4m3fn,
onload_device="cuda",
computation_dtype=torch.float8_e4m3fn,
computation_device="cuda",
),
vram_limit=None,
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
image.save("image.jpg")