Compare commits

...

3 Commits
dpo ... fp8

Author SHA1 Message Date
Zhongjie Duan
17714a8cc8 Merge branch 'main' into fp8 2025-08-07 16:40:44 +08:00
Artiprocher
a947459bda refine README 2025-08-07 16:32:01 +08:00
Artiprocher
a0eec8c673 support qwen-image-fp8 2025-08-07 16:20:50 +08:00
7 changed files with 224 additions and 40 deletions

View File

@@ -1,10 +1,44 @@
import torch
import torch, math
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange
from .sd3_dit import TimestepEmbeddings, RMSNorm
from .flux_dit import AdaLayerNorm
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):
if FLASH_ATTN_3_AVAILABLE:
if not enable_fp8_attention:
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v)
if isinstance(x, tuple):
x = x[0]
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
origin_dtype = q.dtype
q_std, k_std, v_std = q.std(), k.std(), v.std()
q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
q = rearrange(q, "b n s d -> b s n d", n=num_heads)
k = rearrange(k, "b n s d -> b s n d", n=num_heads)
v = rearrange(v, "b n s d -> b s n d", n=num_heads)
x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
if isinstance(x, tuple):
x = x[0]
x = x.to(origin_dtype) * v_std
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
return x
class ApproximateGELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
@@ -160,6 +194,7 @@ class QwenDoubleStreamAttention(nn.Module):
text: torch.FloatTensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
enable_fp8_attention: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -187,9 +222,7 @@ class QwenDoubleStreamAttention(nn.Module):
joint_k = torch.cat([txt_k, img_k], dim=2)
joint_v = torch.cat([txt_v, img_v], dim=2)
joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v, attn_mask=attention_mask)
joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype)
joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
txt_attn_output = joint_attn_out[:, :seq_txt, :]
img_attn_output = joint_attn_out[:, seq_txt:, :]
@@ -247,6 +280,7 @@ class QwenImageTransformerBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each
@@ -262,7 +296,8 @@ class QwenImageTransformerBlock(nn.Module):
image=img_modulated,
text=txt_modulated,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
attention_mask: Optional[torch.Tensor] = None,
enable_fp8_attention = False,
)
image = image + img_gate * img_attn_out

View File

@@ -63,14 +63,12 @@ class QwenImagePipeline(BasePipeline):
return loss
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder is not None:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
dtype = next(iter(self.text_encoder.parameters())).dtype
@@ -96,31 +94,54 @@ class QwenImagePipeline(BasePipeline):
from ..models.qwen_image_dit import RMSNorm
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if not enable_dit_fp8_computation:
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
else:
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae is not None:
from ..models.qwen_image_vae import QwenImageRMS_norm
dtype = next(iter(self.vae.parameters())).dtype
@@ -217,6 +238,7 @@ class QwenImagePipeline(BasePipeline):
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"enable_fp8_attention": enable_fp8_attention,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
}
@@ -418,6 +440,7 @@ def model_fn_qwen_image(
entity_prompt_emb=None,
entity_prompt_emb_mask=None,
entity_masks=None,
enable_fp8_attention=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs
@@ -451,6 +474,7 @@ def model_fn_qwen_image(
temb=conditioning,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
)
image = dit.norm_out(image, conditioning)

View File

@@ -110,8 +110,48 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
self.lora_A_weights = []
self.lora_B_weights = []
self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
def fp8_linear(
self,
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
device = input.device
origin_dtype = input.dtype
origin_shape = input.shape
input = input.reshape(-1, origin_shape[-1])
x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
fp8_max = 448.0
# For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
# To avoid overflow and ensure numerical compatibility during FP8 computation,
# we scale down the input by 2.0 in advance.
# This scaling will be compensated later during the final result scaling.
if self.computation_dtype == torch.float8_e4m3fnuz:
fp8_max = fp8_max / 2.0
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
input = input / (scale_a + 1e-8)
input = input.to(self.computation_dtype)
weight = weight.to(self.computation_dtype)
bias = bias.to(torch.bfloat16)
result = torch._scaled_mm(
input,
weight.T,
scale_a=scale_a,
scale_b=scale_b.T,
bias=bias,
out_dtype=origin_dtype,
)
new_shape = origin_shape[:-1] + result.shape[-1:]
result = result.reshape(new_shape)
return result
def forward(self, x, *args, **kwargs):
# VRAM management
if self.state == 2:
weight, bias = self.weight, self.bias
else:
@@ -123,8 +163,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
out = torch.nn.functional.linear(x, weight, bias)
# Linear forward
if self.enable_fp8:
out = self.fp8_linear(x, weight, bias)
else:
out = torch.nn.functional.linear(x, weight, bias)
# LoRA
if len(self.lora_A_weights) == 0:
# No LoRA
return out

View File

@@ -164,6 +164,7 @@ After enabling VRAM management, the framework will automatically choose a memory
* `vram_limit`: VRAM usage limit in GB. By default, it uses all free VRAM on the device. Note that this is not a strict limit. If the set limit is too low but actual free VRAM is enough, the model will run with minimal VRAM use. Set it to 0 for the smallest possible VRAM use.
* `vram_buffer`: VRAM buffer size in GB. Default is 0.5GB. A buffer is needed because large network layers may use more VRAM than expected during loading. The best value is the VRAM size of the largest model layer.
* `num_persistent_param_in_dit`: Number of parameters to keep in VRAM in the DiT model. Default is no limit. This option will be removed in the future. Do not rely on it.
* `enable_dit_fp8_computation`: Whether to enable FP8 computation in the DiT model. This is only applicable to GPUs that support FP8 operations (e.g., H200, etc.). Disabled by default.
</details>
@@ -172,7 +173,11 @@ After enabling VRAM management, the framework will automatically choose a memory
<summary>Inference Acceleration</summary>
Inference acceleration for Qwen-Image is under development. Please stay tuned!
* FP8 Quantization: Choose the appropriate quantization method based on your hardware and requirements.
* GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
* GPUs that support FP8 operations (e.g., H200, etc.): Please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention). Otherwise, FP8 acceleration will only apply to Linear layers.
* Faster inference but higher VRAM usage: Use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
* Slightly slower inference but lower VRAM usage: Use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
</details>

View File

@@ -164,6 +164,7 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
* `vram_limit`: 显存占用量限制GB默认占用设备上的剩余显存。注意这不是一个绝对限制当设置的显存不足以支持模型进行推理但实际可用显存足够时将会以最小化显存占用的形式进行推理。将其设置为0时将会实现理论最小显存占用。
* `vram_buffer`: 显存缓冲区大小GB默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。
* `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。
* `enable_dit_fp8_computation`: 是否启用 DiT 模型中的 FP8 计算,仅适用于支持 FP8 运算的 GPU例如 H200 等),默认不启用。
</details>
@@ -172,7 +173,11 @@ FP8 量化能够大幅度减少显存占用,但不会加速,部分模型在
<summary>推理加速</summary>
Qwen-Image 的推理加速技术正在开发中,敬请期待!
* FP8 量化:根据您的硬件与需求,请选择合适的量化方式
* GPU 不支持 FP8 计算(例如 A100、4090 等FP8 量化仅能降低显存占用,无法加速,代码:[./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
* GPU 支持 FP8 运算(例如 H200 等):请安装 [Flash Attention 3](https://github.com/Dao-AILab/flash-attention),否则 FP8 加速仅对 Linear 层生效
* 更快的速度,但更大的显存:请使用 [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
* 稍慢的速度,但更小的显存:请使用 [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)
</details>

View File

@@ -0,0 +1,18 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management(enable_dit_fp8_computation=True)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
image.save("image.jpg")

View File

@@ -0,0 +1,51 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.models.qwen_image_dit import RMSNorm
from diffsynth.vram_management.layers import enable_vram_management, AutoWrappedLinear, AutoWrappedModule
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
enable_vram_management(
pipe.dit,
module_map = {
RMSNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=torch.bfloat16,
offload_device="cuda",
onload_dtype=torch.bfloat16,
onload_device="cuda",
computation_dtype=torch.bfloat16,
computation_device="cuda",
),
vram_limit=None,
)
enable_vram_management(
pipe.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=torch.float8_e4m3fn,
offload_device="cuda",
onload_dtype=torch.float8_e4m3fn,
onload_device="cuda",
computation_dtype=torch.float8_e4m3fn,
computation_device="cuda",
),
vram_limit=None,
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
image.save("image.jpg")