Compare commits

...

3 Commits

Author          SHA1        Message                  Date
Artiprocher     94e7e800b2  qwen-image-acc-adapter   2025-08-08 16:56:06 +08:00
Artiprocher     e85f42b474  qwen-image-acc-adapter   2025-08-08 16:51:42 +08:00
Zhongjie Duan   32cf5d32ce  Qwen-Image FP8 (#761)    2025-08-07 16:56:02 +08:00
                            * support qwen-image-fp8
                            * refine README
                            * bugfix
                            * bugfix
12 changed files with 355 additions and 45 deletions

View File

@@ -72,6 +72,7 @@ from ..models.flux_lora_encoder import FluxLoRAEncoder
from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
from ..models.nexus_gen import NexusGenAutoregressiveModel
+from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
@@ -165,6 +166,7 @@ model_loader_configs = [
    (None, "63c969fd37cce769a90aa781fbff5f81", ["flux_dit", "nexus_gen_editing_adapter"], [FluxDiT, NexusGenImageEmbeddingMerger], "civitai"),
    (None, "2bd19e845116e4f875a0a048e27fc219", ["nexus_gen_llm"], [NexusGenAutoregressiveModel], "civitai"),
    (None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
+    (None, "ae9d13bfc578702baf6445d2cf3d1d46", ["qwen_image_accelerate_adapter"], [QwenImageAccelerateAdapter], "civitai"),
    (None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
    (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
]
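The new row appears to register the adapter by a hash of its state-dict key names, so that a checkpoint such as `models/adapter.safetensors` is recognized and loaded as a `QwenImageAccelerateAdapter`; the initialization script later in this diff prints this hash via `hash_state_dict_keys`. A minimal sketch of the idea, with a hypothetical hashing scheme since the real one lives inside diffsynth:

```python
# Hypothetical sketch of hash-based model detection; the actual hashing is done by
# diffsynth's hash_state_dict_keys, so the exact scheme below is an assumption.
import hashlib

def hash_keys(state_dict_keys) -> str:
    return hashlib.md5(",".join(sorted(state_dict_keys)).encode("utf-8")).hexdigest()

# Loader configs map a key hash to model names, model classes, and source format.
registry = {
    "ae9d13bfc578702baf6445d2cf3d1d46": ("qwen_image_accelerate_adapter", "civitai"),
}

def detect_model(state_dict):
    return registry.get(hash_keys(state_dict.keys()))
```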

View File

@@ -0,0 +1,63 @@
from .qwen_image_dit import QwenImageTransformerBlock, AdaLayerNorm, TimestepEmbeddings
from einops import rearrange
import torch


class QwenImageAccelerateAdapter(torch.nn.Module):
    # A small output head that refines the DiT's final hidden states with extra
    # transformer blocks and a residual path from the packed input latents.
    def __init__(
        self,
        num_layers: int = 1,
    ):
        super().__init__()
        self.proj_latents_in = torch.nn.Linear(64, 3072)
        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
        self.transformer_blocks = torch.nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=3072,
                    num_attention_heads=24,
                    attention_head_dim=128,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = AdaLayerNorm(3072, single=True)
        self.proj_out = torch.nn.Linear(3072, 64)
        self.proj_latents_out = torch.nn.Linear(64, 64)

    def forward(
        self,
        latents=None,
        image=None,
        text=None,
        image_rotary_emb=None,
        timestep=None,
    ):
        # Pack 2x2 latent patches into 64-dim tokens, matching the DiT patchification.
        latents = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        image = image + self.proj_latents_in(latents)
        conditioning = self.time_text_embed(timestep, image.dtype)
        for block in self.transformer_blocks:
            text, image = block(
                image=image,
                text=text,
                temb=conditioning,
                image_rotary_emb=image_rotary_emb,
            )
        image = self.norm_out(image, conditioning)
        image = self.proj_out(image)
        image = image + self.proj_latents_out(latents)
        return image

    @staticmethod
    def state_dict_converter():
        return QwenImageAccelerateAdapterStateDictConverter()


class QwenImageAccelerateAdapterStateDictConverter():
    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        return state_dict
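Illustrative only (not part of the diff): a quick way to see why `proj_latents_in` is a 64-to-3072 projection is that the 16-channel latents are packed into 2x2 patches, which yields 64-dim tokens; the latent shape below is an assumption.

```python
# Shape check for the patch packing used by the adapter (and the DiT itself).
import torch
from einops import rearrange

latents = torch.randn(1, 16, 128, 128)  # (B, C, H, W), assumed 16-channel latent
tokens = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
print(tokens.shape)  # torch.Size([1, 4096, 64]); proj_latents_in lifts 64 -> 3072
```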

View File

@@ -1,10 +1,44 @@
-import torch
+import torch, math
import torch.nn as nn
from typing import Tuple, Optional, Union, List
from einops import rearrange
from .sd3_dit import TimestepEmbeddings, RMSNorm
from .flux_dit import AdaLayerNorm
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
+
+def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask=None, enable_fp8_attention: bool = False):
+    if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
+        if not enable_fp8_attention:
+            q = rearrange(q, "b n s d -> b s n d", n=num_heads)
+            k = rearrange(k, "b n s d -> b s n d", n=num_heads)
+            v = rearrange(v, "b n s d -> b s n d", n=num_heads)
+            x = flash_attn_interface.flash_attn_func(q, k, v)
+            if isinstance(x, tuple):
+                x = x[0]
+            x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
+        else:
+            origin_dtype = q.dtype
+            q_std, k_std, v_std = q.std(), k.std(), v.std()
+            q, k, v = (q / q_std).to(torch.float8_e4m3fn), (k / k_std).to(torch.float8_e4m3fn), (v / v_std).to(torch.float8_e4m3fn)
+            q = rearrange(q, "b n s d -> b s n d", n=num_heads)
+            k = rearrange(k, "b n s d -> b s n d", n=num_heads)
+            v = rearrange(v, "b n s d -> b s n d", n=num_heads)
+            x = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=q_std * k_std / math.sqrt(q.size(-1)))
+            if isinstance(x, tuple):
+                x = x[0]
+            x = x.to(origin_dtype) * v_std
+            x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
+    else:
+        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
+        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
+    return x
class ApproximateGELU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
@@ -160,6 +194,7 @@ class QwenDoubleStreamAttention(nn.Module):
        text: torch.FloatTensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
+        enable_fp8_attention: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
        txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -187,9 +222,7 @@ class QwenDoubleStreamAttention(nn.Module):
        joint_k = torch.cat([txt_k, img_k], dim=2)
        joint_v = torch.cat([txt_v, img_v], dim=2)
-        joint_attn_out = torch.nn.functional.scaled_dot_product_attention(joint_q, joint_k, joint_v, attn_mask=attention_mask)
-        joint_attn_out = rearrange(joint_attn_out, 'b h s d -> b s (h d)').to(joint_q.dtype)
+        joint_attn_out = qwen_image_flash_attention(joint_q, joint_k, joint_v, num_heads=joint_q.shape[1], attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention).to(joint_q.dtype)
        txt_attn_output = joint_attn_out[:, :seq_txt, :]
        img_attn_output = joint_attn_out[:, seq_txt:, :]
@@ -247,6 +280,7 @@ class QwenImageTransformerBlock(nn.Module):
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
+        enable_fp8_attention = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
@@ -263,6 +297,7 @@ class QwenImageTransformerBlock(nn.Module):
            text=txt_modulated,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
+            enable_fp8_attention=enable_fp8_attention,
        )
        image = image + img_gate * img_attn_out
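For the FP8 path, `qwen_image_flash_attention` normalizes q, k, and v by their standard deviations before casting to `float8_e4m3fn`, then folds the scales back in through `softmax_scale` and a final multiplication by `v_std`. The compensation is exact in full precision; the snippet below checks the identity with PyTorch's SDPA standing in for Flash Attention 3 (illustrative only, no FP8 cast here; the `scale` argument needs PyTorch 2.1+).

```python
# Sanity check of the scale-compensation identity used by the FP8 attention path:
#   softmax((q/qs)(k/ks)^T * qs*ks/sqrt(d)) @ (v/vs) * vs == softmax(q k^T / sqrt(d)) @ v
import math
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 24, 64, 128) for _ in range(3))
reference = F.scaled_dot_product_attention(q, k, v)

q_std, k_std, v_std = q.std(), k.std(), v.std()
rescaled = F.scaled_dot_product_attention(
    q / q_std, k / k_std, v / v_std,
    scale=(q_std * k_std).item() / math.sqrt(q.size(-1)),
) * v_std

print(torch.allclose(reference, rescaled, atol=1e-5))  # True up to rounding
```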

View File

@@ -6,6 +6,7 @@ from tqdm import tqdm
from einops import rearrange
from ..models import ModelManager, load_state_dict
+from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
@@ -27,12 +28,13 @@ class QwenImagePipeline(BasePipeline):
        from transformers import Qwen2Tokenizer
        self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
+        self.accelerate_adapter: QwenImageAccelerateAdapter = None
        self.text_encoder: QwenImageTextEncoder = None
        self.dit: QwenImageDiT = None
        self.vae: QwenImageVAE = None
        self.tokenizer: Qwen2Tokenizer = None
        self.unit_runner = PipelineUnitRunner()
-        self.in_iteration_models = ("dit",)
+        self.in_iteration_models = ("accelerate_adapter", "dit",)
        self.units = [
            QwenImageUnit_ShapeChecker(),
            QwenImageUnit_NoiseInitializer(),
@@ -63,14 +65,12 @@ class QwenImagePipeline(BasePipeline):
        return loss

-    def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
+    def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
        self.vram_management_enabled = True
-        if num_persistent_param_in_dit is not None:
-            vram_limit = None
-        else:
-            if vram_limit is None:
-                vram_limit = self.get_vram()
-            vram_limit = vram_limit - vram_buffer
+        if vram_limit is None:
+            vram_limit = self.get_vram()
+        vram_limit = vram_limit - vram_buffer
        if self.text_encoder is not None:
            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm
            dtype = next(iter(self.text_encoder.parameters())).dtype
@@ -96,31 +96,54 @@ class QwenImagePipeline(BasePipeline):
            from ..models.qwen_image_dit import RMSNorm
            dtype = next(iter(self.dit.parameters())).dtype
            device = "cpu" if vram_limit is not None else self.device
-            enable_vram_management(
-                self.dit,
-                module_map = {
-                    RMSNorm: AutoWrappedModule,
-                    torch.nn.Linear: AutoWrappedLinear,
-                },
-                module_config = dict(
-                    offload_dtype=dtype,
-                    offload_device="cpu",
-                    onload_dtype=dtype,
-                    onload_device=device,
-                    computation_dtype=self.torch_dtype,
-                    computation_device=self.device,
-                ),
-                max_num_param=num_persistent_param_in_dit,
-                overflow_module_config = dict(
-                    offload_dtype=dtype,
-                    offload_device="cpu",
-                    onload_dtype=dtype,
-                    onload_device="cpu",
-                    computation_dtype=self.torch_dtype,
-                    computation_device=self.device,
-                ),
-                vram_limit=vram_limit,
-            )
+            if not enable_dit_fp8_computation:
+                enable_vram_management(
+                    self.dit,
+                    module_map = {
+                        RMSNorm: AutoWrappedModule,
+                        torch.nn.Linear: AutoWrappedLinear,
+                    },
+                    module_config = dict(
+                        offload_dtype=dtype,
+                        offload_device="cpu",
+                        onload_dtype=dtype,
+                        onload_device=device,
+                        computation_dtype=self.torch_dtype,
+                        computation_device=self.device,
+                    ),
+                    vram_limit=vram_limit,
+                )
+            else:
+                enable_vram_management(
+                    self.dit,
+                    module_map = {
+                        RMSNorm: AutoWrappedModule,
+                    },
+                    module_config = dict(
+                        offload_dtype=dtype,
+                        offload_device="cpu",
+                        onload_dtype=dtype,
+                        onload_device=device,
+                        computation_dtype=self.torch_dtype,
+                        computation_device=self.device,
+                    ),
+                    vram_limit=vram_limit,
+                )
+                enable_vram_management(
+                    self.dit,
+                    module_map = {
+                        torch.nn.Linear: AutoWrappedLinear,
+                    },
+                    module_config = dict(
+                        offload_dtype=dtype,
+                        offload_device="cpu",
+                        onload_dtype=dtype,
+                        onload_device=device,
+                        computation_dtype=dtype,
+                        computation_device=self.device,
+                    ),
+                    vram_limit=vram_limit,
+                )
        if self.vae is not None:
            from ..models.qwen_image_vae import QwenImageRMS_norm
            dtype = next(iter(self.vae.parameters())).dtype
@@ -166,6 +189,7 @@ class QwenImagePipeline(BasePipeline):
        pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
        pipe.dit = model_manager.fetch_model("qwen_image_dit")
        pipe.vae = model_manager.fetch_model("qwen_image_vae")
+        pipe.accelerate_adapter = model_manager.fetch_model("qwen_image_accelerate_adapter")
        if tokenizer_config is not None and pipe.text_encoder is not None:
            tokenizer_config.download_if_necessary()
            from transformers import Qwen2Tokenizer
@@ -195,6 +219,8 @@ class QwenImagePipeline(BasePipeline):
        eligen_entity_prompts: list[str] = None,
        eligen_entity_masks: list[Image.Image] = None,
        eligen_enable_on_negative: bool = False,
+        # FP8
+        enable_fp8_attention: bool = False,
        # Tile
        tiled: bool = False,
        tile_size: int = 128,
@@ -217,6 +243,7 @@ class QwenImagePipeline(BasePipeline):
"input_image": input_image, "denoising_strength": denoising_strength, "input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width, "height": height, "width": width,
"seed": seed, "rand_device": rand_device, "seed": seed, "rand_device": rand_device,
"enable_fp8_attention": enable_fp8_attention,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
} }
@@ -409,6 +436,7 @@ class QwenImageUnit_EntityControl(PipelineUnit):
def model_fn_qwen_image(
    dit: QwenImageDiT = None,
+    accelerate_adapter: QwenImageAccelerateAdapter = None,
    latents=None,
    timestep=None,
    prompt_emb=None,
@@ -418,6 +446,7 @@ def model_fn_qwen_image(
    entity_prompt_emb=None,
    entity_prompt_emb_mask=None,
    entity_masks=None,
+    enable_fp8_attention=False,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs
@@ -451,10 +480,14 @@ def model_fn_qwen_image(
            temb=conditioning,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
+            enable_fp8_attention=enable_fp8_attention,
        )

-    image = dit.norm_out(image, conditioning)
-    image = dit.proj_out(image)
+    if accelerate_adapter is not None:
+        image = accelerate_adapter(latents, image, text, image_rotary_emb, timestep)
+    else:
+        image = dit.norm_out(image, conditioning)
+        image = dit.proj_out(image)
    latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
    return latents

View File

@@ -31,11 +31,13 @@ class FlowMatchScheduler():
        self.set_timesteps(num_inference_steps)

-    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None):
+    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, random_sigmas=False):
        if shift is not None:
            self.shift = shift
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
-        if self.extra_one_step:
+        if random_sigmas:
+            self.sigmas = torch.Tensor(sorted([torch.rand((1,)).item() * (sigma_start - self.sigma_min) for i in range(num_inference_steps - 1)] + [sigma_start], reverse=True))
+        elif self.extra_one_step:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
        else:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
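The new `random_sigmas` option replaces the fixed linspace grid with randomly drawn sigmas, presumably intended for training (e.g., distilling the accelerate adapter) rather than fixed-grid inference. A standalone illustration of what it produces (not the scheduler itself):

```python
# Illustration of random_sigmas=True: draw num_inference_steps - 1 values
# uniformly in [0, sigma_start - sigma_min], keep sigma_start as the largest
# entry, and sort everything in descending order.
import torch

torch.manual_seed(0)
num_inference_steps, sigma_min, sigma_start = 5, 0.0, 1.0
sigmas = torch.Tensor(sorted(
    [torch.rand((1,)).item() * (sigma_start - sigma_min) for _ in range(num_inference_steps - 1)]
    + [sigma_start],
    reverse=True,
))
print(sigmas)  # 5 sigmas, monotonically decreasing, starting at sigma_start
```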

View File

@@ -110,8 +110,48 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
        self.lora_A_weights = []
        self.lora_B_weights = []
        self.lora_merger = None
+        self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
+
+    def fp8_linear(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        device = input.device
+        origin_dtype = input.dtype
+        origin_shape = input.shape
+        input = input.reshape(-1, origin_shape[-1])
+        x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
+        fp8_max = 448.0
+        # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
+        # To avoid overflow and ensure numerical compatibility during FP8 computation,
+        # we scale down the input by 2.0 in advance.
+        # This scaling will be compensated later during the final result scaling.
+        if self.computation_dtype == torch.float8_e4m3fnuz:
+            fp8_max = fp8_max / 2.0
+        scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
+        scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
+        input = input / (scale_a + 1e-8)
+        input = input.to(self.computation_dtype)
+        weight = weight.to(self.computation_dtype)
+        bias = bias.to(torch.bfloat16)
+        result = torch._scaled_mm(
+            input,
+            weight.T,
+            scale_a=scale_a,
+            scale_b=scale_b.T,
+            bias=bias,
+            out_dtype=origin_dtype,
+        )
+        new_shape = origin_shape[:-1] + result.shape[-1:]
+        result = result.reshape(new_shape)
+        return result
+
    def forward(self, x, *args, **kwargs):
+        # VRAM management
        if self.state == 2:
            weight, bias = self.weight, self.bias
        else:
@@ -123,8 +163,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
            else:
                weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
                bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
-        out = torch.nn.functional.linear(x, weight, bias)
+
+        # Linear forward
+        if self.enable_fp8:
+            out = self.fp8_linear(x, weight, bias)
+        else:
+            out = torch.nn.functional.linear(x, weight, bias)
+
+        # LoRA
        if len(self.lora_A_weights) == 0:
            # No LoRA
            return out
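`fp8_linear` scales each input row so that its largest magnitude fits inside the FP8 E4M3 range (448), quantizes, and passes the compensation scales to `torch._scaled_mm`. The quantize-then-rescale idea can be simulated without FP8 hardware; a minimal sketch follows (CPU-only, using a round-trip cast instead of `_scaled_mm`).

```python
# Minimal simulation of the per-row dynamic scaling used by fp8_linear above.
# Runs on CPU; torch._scaled_mm itself needs an FP8-capable GPU.
import torch

x = torch.randn(8, 64) * 30.0                    # activations with a wide dynamic range
w = torch.randn(128, 64)

x_max = x.abs().amax(dim=-1, keepdim=True)
scale_a = torch.clamp(x_max / 448.0, min=1.0)    # 448 is the e4m3fn maximum value
x_fp8 = (x / scale_a).to(torch.float8_e4m3fn)    # quantize the scaled rows

# Dequantize-and-matmul stands in for the fused torch._scaled_mm call.
y_fp8 = (x_fp8.to(torch.float32) * scale_a) @ w.T
y_ref = x @ w.T
print((y_fp8 - y_ref).abs().max())               # small quantization error
```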

View File

@@ -164,6 +164,7 @@ After enabling VRAM management, the framework will automatically choose a memory
* `vram_limit`: VRAM usage limit in GB. By default, it uses all free VRAM on the device. Note that this is not a strict limit. If the set limit is too low but actual free VRAM is enough, the model will run with minimal VRAM use. Set it to 0 for the smallest possible VRAM use.
* `vram_buffer`: VRAM buffer size in GB. Default is 0.5GB. A buffer is needed because large network layers may use more VRAM than expected during loading. The best value is the VRAM size of the largest model layer.
* `num_persistent_param_in_dit`: Number of parameters to keep in VRAM in the DiT model. Default is no limit. This option will be removed in the future. Do not rely on it.
+* `enable_dit_fp8_computation`: Whether to enable FP8 computation in the DiT model. This is only applicable to GPUs that support FP8 operations (e.g., H200, etc.). Disabled by default.

</details>
@@ -172,7 +173,11 @@ After enabling VRAM management, the framework will automatically choose a memory
<summary>Inference Acceleration</summary>

-Inference acceleration for Qwen-Image is under development. Please stay tuned!
+* FP8 Quantization: Choose the appropriate quantization method based on your hardware and requirements.
+    * GPUs that do not support FP8 computation (e.g., A100, 4090, etc.): FP8 quantization will only reduce VRAM usage without speeding up inference. Code: [./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
+    * GPUs that support FP8 operations (e.g., H200, etc.): Please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention). Otherwise, FP8 acceleration will only apply to Linear layers.
+        * Faster inference but higher VRAM usage: Use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
+        * Slightly slower inference but lower VRAM usage: Use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)

</details>
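Condensed, the FP8 path on FP8-capable hardware looks roughly like the following sketch, assembled from the options documented above; the complete, tested versions are the new example scripts later in this diff.

```python
# Sketch only: FP8 offload for the DiT plus FP8 Linear computation and FP8 attention.
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors",
                    offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management(enable_dit_fp8_computation=True)   # FP8 Linear layers in the DiT
image = pipe("an underwater portrait of a girl in a flowing blue dress",
             seed=0, num_inference_steps=40, enable_fp8_attention=True)  # needs Flash Attention 3
image.save("image.jpg")
```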

View File

@@ -164,6 +164,7 @@ FP8 quantization can greatly reduce VRAM usage, but it does not speed up inference; for some models
* `vram_limit`: VRAM usage limit (GB). By default, all remaining free VRAM on the device is used. Note that this is not a hard limit: if the configured limit is not enough for inference but the actually available VRAM is, inference will run in the minimal-VRAM mode. Setting it to 0 gives the theoretical minimum VRAM usage.
* `vram_buffer`: VRAM buffer size (GB), 0.5GB by default. Some larger layers can use more VRAM than expected during onloading, so a buffer is necessary; the theoretical optimum is the VRAM taken by the largest layer in the model.
* `num_persistent_param_in_dit`: Number of parameters kept resident in VRAM in the DiT model (default: unlimited). This parameter will be removed in the future; please do not rely on it.
+* `enable_dit_fp8_computation`: Whether to enable FP8 computation in the DiT model. Only applicable to GPUs that support FP8 operations (e.g., H200). Disabled by default.

</details>
@@ -172,7 +173,11 @@ FP8 quantization can greatly reduce VRAM usage, but it does not speed up inference; for some models
<summary>Inference Acceleration</summary>

-Inference acceleration for Qwen-Image is under development. Please stay tuned!
+* FP8 quantization: choose the quantization method that fits your hardware and requirements
+    * GPUs without FP8 computation support (e.g., A100, 4090): FP8 quantization only reduces VRAM usage and does not accelerate inference. Code: [./model_inference_lor_vram/Qwen-Image.py](./model_inference_lor_vram/Qwen-Image.py)
+    * GPUs with FP8 support (e.g., H200): please install [Flash Attention 3](https://github.com/Dao-AILab/flash-attention); otherwise FP8 acceleration only applies to Linear layers
+        * Faster inference, but higher VRAM usage: use [./accelerate/Qwen-Image-FP8.py](./accelerate/Qwen-Image-FP8.py)
+        * Slightly slower inference, but lower VRAM usage: use [./accelerate/Qwen-Image-FP8-offload.py](./accelerate/Qwen-Image-FP8-offload.py)

</details>

View File

@@ -0,0 +1,18 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        # All weights are kept offloaded on CPU in FP8 and onloaded on demand.
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management(enable_dit_fp8_computation=True)

prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
image.save("image.jpg")

View File

@@ -0,0 +1,51 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.models.qwen_image_dit import RMSNorm
from diffsynth.vram_management.layers import enable_vram_management, AutoWrappedLinear, AutoWrappedModule
import torch

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
# Keep the RMSNorm layers in bfloat16 on the GPU ...
enable_vram_management(
    pipe.dit,
    module_map = {
        RMSNorm: AutoWrappedModule,
    },
    module_config = dict(
        offload_dtype=torch.bfloat16,
        offload_device="cuda",
        onload_dtype=torch.bfloat16,
        onload_device="cuda",
        computation_dtype=torch.bfloat16,
        computation_device="cuda",
    ),
    vram_limit=None,
)
# ... and store and compute the Linear layers in FP8 (float8_e4m3fn).
enable_vram_management(
    pipe.dit,
    module_map = {
        torch.nn.Linear: AutoWrappedLinear,
    },
    module_config = dict(
        offload_dtype=torch.float8_e4m3fn,
        offload_device="cuda",
        onload_dtype=torch.float8_e4m3fn,
        onload_device="cuda",
        computation_dtype=torch.float8_e4m3fn,
        computation_device="cuda",
    ),
    vram_limit=None,
)

prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40, enable_fp8_attention=True)
image.save("image.jpg")

View File

@@ -0,0 +1,32 @@
# This script is for initializing a Qwen-Image-Accelerate-Adapter
from diffsynth import load_state_dict, hash_state_dict_keys
from diffsynth.pipelines.qwen_image import QwenImageAccelerateAdapter
import torch
from safetensors.torch import save_file

# Load the full Qwen-Image DiT state dict (sharded into 9 safetensors files).
state_dict_dit = {}
for i in range(1, 10):
    state_dict_dit.update(load_state_dict(f"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-0000{i}-of-00009.safetensors", torch_dtype=torch.bfloat16, device="cuda"))

adapter = QwenImageAccelerateAdapter().to(dtype=torch.bfloat16, device="cuda")
state_dict_adapter = adapter.state_dict()
state_dict_init = {}
for k in state_dict_adapter:
    if k.startswith("transformer_blocks"):
        # Copy the adapter's transformer block from the DiT's last block (index 59),
        # zeroing two of the six output chunks of the modulation ("_mod") projections.
        name = k.replace("transformer_blocks.0.", "transformer_blocks.59.")
        param = state_dict_dit[name]
        if "_mod." in k:
            param[2*3072: 3*3072] = 0
            param[5*3072: 6*3072] = 0
        state_dict_init[k] = param
    elif k in state_dict_dit:
        # Reuse matching DiT parameters (e.g., the timestep embedding and output head).
        state_dict_init[k] = state_dict_dit[k]
    else:
        state_dict_init[k] = torch.zeros_like(state_dict_adapter[k])
        print("Zero initialized:", k)

adapter.load_state_dict(state_dict_init)
# Print the state-dict key hash (used when registering the adapter in model_loader_configs).
print(hash_state_dict_keys(state_dict_init))
save_file(state_dict_init, "models/adapter.safetensors")

test_accelerate.py (new file)
View File

@@ -0,0 +1,18 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ModelConfig("models/adapter.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)

# With the accelerate adapter loaded, sample with only 4 steps and no CFG.
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=4, cfg_scale=1)
image.save("image.jpg")