Qwen-Image FP8 (#761)

* support qwen-image-fp8

* refine README

* bugfix

* bugfix
This commit is contained in:
Zhongjie Duan
2025-08-07 16:56:02 +08:00
committed by GitHub
parent 4f7c3b6a1e
commit 32cf5d32ce
7 changed files with 225 additions and 39 deletions

View File

@@ -1,10 +1,44 @@
import torch
import torch, math
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange
from .sd3_dit import TimestepEmbeddings, RMSNorm
from .flux_dit import AdaLayerNorm
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):
if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
if not enable_fp8_attention:
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x, tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
origin_dtype = q.dtype
q_std, k_std, v_std = q.std(), k.std(), v.std()
q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
if isinstance(x, tuple):
x = x[0]
x = x.to(origin_dtype) * v_std
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
return x
class ApproximateGELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
@@ -160,6 +194,7 @@ class QwenDoubleStreamAttention(nn.Module):
text: torch.FloatTensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
enable_fp8_attention: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -187,9 +222,7 @@ class QwenDoubleStreamAttention(nn.Module):
joint_k = torch.cat([txt_k, img_k], dim=2)
joint_v = torch.cat([txt_v, img_v], dim=2)
joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v, attn_mask=attention_mask)
joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype)
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
txt_attn_output = joint_attn_out[:, :seq_txt, :]
img_attn_output = joint_attn_out[:, seq_txt:, :]
@@ -247,6 +280,7 @@ class QwenImageTransformerBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
@@ -263,6 +297,7 @@ class QwenImageTransformerBlock(nn.Module):
text=txt_modulated,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
)
image = image + img_gate * img_attn_out

View File

@@ -63,14 +63,12 @@ class QwenImagePipeline(BasePipeline):
return loss
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
dtype = next(iter(self.text_encoder.parameters())).dtype
@@ -96,31 +94,54 @@ class QwenImagePipeline(BasePipeline):
from ..models.qwen_image_dit import RMSNorm
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if not enable_dit_fp8_computation:
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
else:
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
from ..models.qwen_image_vae import QwenImageRMS_norm
dtype = next(iter(self.vae.parameters())).dtype
@@ -195,6 +216,8 @@ class QwenImagePipeline(BasePipeline):
eligen_entity_prompts: list[str] = None,
eligen_entity_masks: list[Image.Image] = None,
eligen_enable_on_negative: bool = False,
# FP8
enable_fp8_attention: bool = False,
# Tile
tiled: bool = False,
tile_size: int = 128,
@@ -217,6 +240,7 @@ class QwenImagePipeline(BasePipeline):
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"enable_fp8_attention": enable_fp8_attention,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
}
@@ -418,6 +442,7 @@ def model_fn_qwen_image(
entity_prompt_emb=None,
entity_prompt_emb_mask=None,
entity_masks=None,
enable_fp8_attention=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs
@@ -451,6 +476,7 @@ def model_fn_qwen_image(
temb=conditioning,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
)
image = dit.norm_out(image, conditioning)

View File

@@ -110,8 +110,48 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
self.lora_A_weights = []
self.lora_B_weights = []
self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
def fp8_linear(
self,
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
device = input.device
origin_dtype = input.dtype
origin_shape = input.shape
input = input.reshape(-1, origin_shape[-1])
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
fp8_max = 448.0
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
# To avoid overflow and ensure numerical compatibility during FP8 computation,
# we scale down the input by 2.0 in advance.
# This scaling will be compensated later during the final result scaling.
if self.computation_dtype == torch.float8_e4m3fnuz:
fp8_max = fp8_max / 2.0
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
input = input / (scale_a + 1e-8)
input = input.to(self.computation_dtype)
weight = weight.to(self.computation_dtype)
bias = bias.to(torch.bfloat16)
result = torch._scaled_mm(
input,
weight.T,
scale_a=scale_a,
scale_b=scale_b.T,
bias=bias,
out_dtype=origin_dtype,
)
new_shape = origin_shape[:-1] + result.shape[-1:]
result = result.reshape(new_shape)
return result
def forward(self, x, *args, **kwargs):
# VRAM management
if self.state == 2:
weight, bias = self.weight, self.bias
else:
@@ -123,8 +163,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
out = torch.nn.functional.linear(x, weight, bias)
# Linear forward
if self.enable_fp8:
out = self.fp8_linear(x, weight, bias)
else:
out = torch.nn.functional.linear(x, weight, bias)
# LoRA
if len(self.lora_A_weights) == 0:
# No LoRA
return out