Compare commits

3 Commits

Author         SHA1        Message                        Date
Zhongjie Duan  17714a8cc8  Merge branch 'main' into fp8   2025-08-07 16:40:44 +08:00
Artiprocher    a947459bda  refine README                  2025-08-07 16:32:01 +08:00
Artiprocher    a0eec8c673  support qwen-image-fp8         2025-08-07 16:20:50 +08:00
7 changed files with 9 additions and 135 deletions

View File

@@ -72,7 +72,6 @@ from ..models.flux_lora_encoder import FluxLoRAEncoder
 from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger
 from ..models.nexus_gen import NexusGenAutoregressiveModel
-from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter
 from ..models.qwen_image_dit import QwenImageDiT
 from ..models.qwen_image_text_encoder import QwenImageTextEncoder
 from ..models.qwen_image_vae import QwenImageVAE
@@ -166,7 +165,6 @@ model_loader_configs = [
     (None, "63c969fd37cce769a90aa781fbff5f81", ["flux_dit", "nexus_gen_editing_adapter"], [FluxDiT, NexusGenImageEmbeddingMerger], "civitai"),
     (None, "2bd19e845116e4f875a0a048e27fc219", ["nexus_gen_llm"], [NexusGenAutoregressiveModel], "civitai"),
     (None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"),
-    (None, "ae9d13bfc578702baf6445d2cf3d1d46", ["qwen_image_accelerate_adapter"], [QwenImageAccelerateAdapter], "civitai"),
     (None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"),
     (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
 ]
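
Editor's note: each tuple in model_loader_configs maps a hash of a checkpoint's state-dict keys to the classes that should load it, so removing the adapter means deleting both its import and its "ae9d13bfc578702baf6445d2cf3d1d46" entry. Below is a minimal sketch of that hash-based dispatch, assuming the tuple layout shown above (path, keys_hash, model_names, model_classes, source); the config list and file path here are stand-ins, not repository code.

import torch
from diffsynth import load_state_dict, hash_state_dict_keys

# Stand-in list with the same shape as the real model_loader_configs.
model_loader_configs = [
    (None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], ["QwenImageDiT"], "civitai"),
]

state_dict = load_state_dict("models/checkpoint.safetensors")  # hypothetical file
keys_hash = hash_state_dict_keys(state_dict)
for path, known_hash, names, classes, source in model_loader_configs:
    if known_hash == keys_hash:
        print("load", names, "as", classes, "from", source)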

View File

@@ -1,63 +0,0 @@
-from .qwen_image_dit import QwenImageTransformerBlock, AdaLayerNorm, TimestepEmbeddings
-from einops import rearrange
-import torch
-
-
-class QwenImageAccelerateAdapter(torch.nn.Module):
-    def __init__(
-        self,
-        num_layers: int = 1,
-    ):
-        super().__init__()
-        self.proj_latents_in = torch.nn.Linear(64, 3072)
-        self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True)
-        self.transformer_blocks = torch.nn.ModuleList(
-            [
-                QwenImageTransformerBlock(
-                    dim=3072,
-                    num_attention_heads=24,
-                    attention_head_dim=128,
-                )
-                for _ in range(num_layers)
-            ]
-        )
-        self.norm_out = AdaLayerNorm(3072, single=True)
-        self.proj_out = torch.nn.Linear(3072, 64)
-        self.proj_latents_out = torch.nn.Linear(64, 64)
-
-    def forward(
-        self,
-        latents=None,
-        image=None,
-        text=None,
-        image_rotary_emb=None,
-        timestep=None,
-    ):
-        latents = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
-        image = image + self.proj_latents_in(latents)
-        conditioning = self.time_text_embed(timestep, image.dtype)
-        for block in self.transformer_blocks:
-            text, image = block(
-                image=image,
-                text=text,
-                temb=conditioning,
-                image_rotary_emb=image_rotary_emb,
-            )
-        image = self.norm_out(image, conditioning)
-        image = self.proj_out(image)
-        image = image + self.proj_latents_out(latents)
-        return image
-
-    @staticmethod
-    def state_dict_converter():
-        return QwenImageAccelerateAdapterStateDictConverter()
-
-
-class QwenImageAccelerateAdapterStateDictConverter():
-    def __init__(self):
-        pass
-
-    def from_civitai(self, state_dict):
-        return state_dict
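
Editor's note: the deleted adapter sat after the DiT's transformer blocks (see model_fn_qwen_image below): it re-packs the noisy latents into 2x2 patch tokens, adds them to the DiT's hidden states, runs one extra transformer block, and predicts the output through a learned skip from the packed latents. A shape-level sketch of that packing, with dimensions inferred from the Linear layers above (16 latent channels x 2 x 2 = 64 features per token is my assumption):

import torch
from einops import rearrange

B, C, H, W = 1, 16, 64, 64
latents = torch.randn(B, C, H, W)
# Same rearrange as QwenImageAccelerateAdapter.forward: 2x2 patches become tokens.
tokens = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
print(tokens.shape)  # torch.Size([1, 1024, 64]); proj_latents_in then maps 64 -> 3072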

View File

@@ -13,7 +13,7 @@ except ModuleNotFoundError:
 def qwen_image_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, attention_mask = None, enable_fp8_attention: bool = False):
-    if FLASH_ATTN_3_AVAILABLE and attention_mask is None:
+    if FLASH_ATTN_3_AVAILABLE:
         if not enable_fp8_attention:
             q = rearrange(q, "b n s d -> b s n d", n=num_heads)
             k = rearrange(k, "b n s d -> b s n d", n=num_heads)
@@ -296,8 +296,8 @@ class QwenImageTransformerBlock(nn.Module):
             image=img_modulated,
             text=txt_modulated,
             image_rotary_emb=image_rotary_emb,
-            attention_mask=attention_mask,
-            enable_fp8_attention=enable_fp8_attention,
+            attention_mask: Optional[torch.Tensor] = None,
+            enable_fp8_attention = False,
         )
         image = image + img_gate * img_attn_out
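
Editor's note: this hunk gates the flash-attention-3 path and threads an enable_fp8_attention flag into the transformer block's attention call. The actual fp8 kernel is not shown in this diff, so the following is only a sketch of the cast pattern, using PyTorch's scaled_dot_product_attention as a stand-in for flash-attention 3 and assuming an e4m3 cast for q/k/v:

import torch
import torch.nn.functional as F
from einops import rearrange

def attention_maybe_fp8(q, k, v, num_heads, enable_fp8_attention=False):
    # q, k, v: (b, n, s, d), already split into heads like in qwen_image_flash_attention.
    if enable_fp8_attention:
        # fp8 attention kernels take float8_e4m3fn inputs; SDPA does not,
        # so we emulate the precision loss by round-tripping through fp8.
        q = q.to(torch.float8_e4m3fn).to(torch.bfloat16)
        k = k.to(torch.float8_e4m3fn).to(torch.bfloat16)
        v = v.to(torch.float8_e4m3fn).to(torch.bfloat16)
    out = F.scaled_dot_product_attention(q, k, v)
    return rearrange(out, "b n s d -> b s (n d)", n=num_heads)

q = k = v = torch.randn(1, 24, 128, 128, dtype=torch.bfloat16)
print(attention_maybe_fp8(q, k, v, num_heads=24, enable_fp8_attention=True).shape)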

View File

@@ -6,7 +6,6 @@ from tqdm import tqdm
 from einops import rearrange
 from ..models import ModelManager, load_state_dict
-from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter
 from ..models.qwen_image_dit import QwenImageDiT
 from ..models.qwen_image_text_encoder import QwenImageTextEncoder
 from ..models.qwen_image_vae import QwenImageVAE
@@ -28,13 +27,12 @@ class QwenImagePipeline(BasePipeline):
         from transformers import Qwen2Tokenizer
         self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
-        self.accelerate_adapter: QwenImageAccelerateAdapter = None
         self.text_encoder: QwenImageTextEncoder = None
         self.dit: QwenImageDiT = None
         self.vae: QwenImageVAE = None
         self.tokenizer: Qwen2Tokenizer = None
         self.unit_runner = PipelineUnitRunner()
-        self.in_iteration_models = ("accelerate_adapter", "dit",)
+        self.in_iteration_models = ("dit",)
         self.units = [
             QwenImageUnit_ShapeChecker(),
             QwenImageUnit_NoiseInitializer(),
@@ -189,7 +187,6 @@ class QwenImagePipeline(BasePipeline):
         pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
         pipe.dit = model_manager.fetch_model("qwen_image_dit")
         pipe.vae = model_manager.fetch_model("qwen_image_vae")
-        pipe.accelerate_adapter = model_manager.fetch_model("qwen_image_accelerate_adapter")
         if tokenizer_config is not None and pipe.text_encoder is not None:
             tokenizer_config.download_if_necessary()
             from transformers import Qwen2Tokenizer
@@ -219,8 +216,6 @@
         eligen_entity_prompts: list[str] = None,
         eligen_entity_masks: list[Image.Image] = None,
         eligen_enable_on_negative: bool = False,
-        # FP8
-        enable_fp8_attention: bool = False,
         # Tile
         tiled: bool = False,
         tile_size: int = 128,
@@ -436,7 +431,6 @@ class QwenImageUnit_EntityControl(PipelineUnit):
 
 def model_fn_qwen_image(
     dit: QwenImageDiT = None,
-    accelerate_adapter: QwenImageAccelerateAdapter = None,
     latents=None,
     timestep=None,
     prompt_emb=None,
@@ -482,12 +476,9 @@ def model_fn_qwen_image(
             attention_mask=attention_mask,
             enable_fp8_attention=enable_fp8_attention,
         )
-    if accelerate_adapter is not None:
-        image = accelerate_adapter(latents, image, text, image_rotary_emb, timestep)
-    else:
-        image = dit.norm_out(image, conditioning)
-        image = dit.proj_out(image)
+    image = dit.norm_out(image, conditioning)
+    image = dit.proj_out(image)
     latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
     return latents
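
Editor's note: with the adapter branch gone, model_fn_qwen_image always finishes with the DiT's own norm_out/proj_out and then unpacks tokens back onto the latent grid. The two rearrange patterns are exact inverses; a quick self-contained check (shapes assumed: 16 latent channels, height and width in pixels with a 16x token stride):

import torch
from einops import rearrange

height, width = 1024, 1024
image = torch.randn(1, (height // 16) * (width // 16), 64)  # B, (H W), (C P Q)
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height // 16, W=width // 16, P=2, Q=2)
print(latents.shape)  # torch.Size([1, 16, 128, 128])
assert torch.equal(rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2), image)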

View File

@@ -31,13 +31,11 @@ class FlowMatchScheduler():
         self.set_timesteps(num_inference_steps)
 
-    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, random_sigmas=False):
+    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None):
         if shift is not None:
             self.shift = shift
         sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
-        if random_sigmas:
-            self.sigmas = torch.Tensor(sorted([torch.rand((1,)).item() * (sigma_start - self.sigma_min) for i in range(num_inference_steps - 1)] + [sigma_start], reverse=True))
-        elif self.extra_one_step:
+        if self.extra_one_step:
             self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
         else:
             self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
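
Editor's note: random_sigmas existed only to sample random noise levels during training, which is presumably how the accelerate adapter was distilled, so it is deleted along with the adapter. Inference keeps the two deterministic branches; a standalone illustration of the sigma schedules they produce for 4 steps:

import torch

sigma_min, sigma_max, denoising_strength = 0.0, 1.0, 1.0
num_inference_steps = 4
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength

# extra_one_step=True: take num_inference_steps+1 points and drop the last one.
print(torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1])  # tensor([1.0000, 0.7500, 0.5000, 0.2500])
print(torch.linspace(sigma_start, sigma_min, num_inference_steps))           # tensor([1.0000, 0.6667, 0.3333, 0.0000])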

View File

@@ -1,32 +0,0 @@
-# This script is for initializing a Qwen-Image-Accelerate-Adapter
-from diffsynth import load_state_dict, hash_state_dict_keys
-from diffsynth.pipelines.qwen_image import QwenImageAccelerateAdapter
-import torch
-from safetensors.torch import save_file
-
-
-state_dict_dit = {}
-for i in range(1, 10):
-    state_dict_dit.update(load_state_dict(f"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-0000{i}-of-00009.safetensors", torch_dtype=torch.bfloat16, device="cuda"))
-
-adapter = QwenImageAccelerateAdapter().to(dtype=torch.bfloat16, device="cuda")
-state_dict_adapter = adapter.state_dict()
-
-state_dict_init = {}
-for k in state_dict_adapter:
-    if k.startswith("transformer_blocks"):
-        name = k.replace("transformer_blocks.0.", "transformer_blocks.59.")
-        param = state_dict_dit[name]
-        if "_mod." in k:
-            param[2*3072: 3*3072] = 0
-            param[5*3072: 6*3072] = 0
-        state_dict_init[k] = param
-    elif k in state_dict_dit:
-        state_dict_init[k] = state_dict_dit[k]
-    else:
-        state_dict_init[k] = torch.zeros_like(state_dict_adapter[k])
-        print("Zero initialized:", k)
-
-adapter.load_state_dict(state_dict_init)
-print(hash_state_dict_keys(state_dict_init))
-save_file(state_dict_init, "models/adapter.safetensors")
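
Editor's note: the script copies DiT block 59 into the adapter's single block and zeroes two 3072-wide slices of every "_mod." parameter. If the modulation projection emits six chunks of width 3072 (shift/scale/gate for attention, then shift/scale/gate for the MLP, the usual AdaLN-Zero layout; the chunk order is my assumption, the slice indices are from the script), those slices are the two gate chunks, so the copied block initially contributes nothing and the adapter starts as a near-identity mapping. A toy illustration:

import torch

dim = 3072
mod = torch.randn(6 * dim)   # one modulation vector, six chunks of width dim
mod[2*dim:3*dim] = 0         # assumed attention gate chunk
mod[5*dim:6*dim] = 0         # assumed MLP gate chunk
shift_a, scale_a, gate_a, shift_m, scale_m, gate_m = mod.chunk(6)
x, attn_out = torch.randn(dim), torch.randn(dim)
print(torch.equal(x + gate_a * attn_out, x))  # True: the gated residual is an identity at init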

View File

@@ -1,18 +0,0 @@
-from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
-import torch
-
-
-pipe = QwenImagePipeline.from_pretrained(
-    torch_dtype=torch.bfloat16,
-    device="cuda",
-    model_configs=[
-        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
-        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
-        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
-        ModelConfig("models/adapter.safetensors")
-    ],
-    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
-)
-prompt = "Exquisite portrait, a girl underwater, flowing blue dress, hair drifting gently, translucent light and shadow, surrounded by bubbles, serene face, fine details, dreamy and beautiful."
-image = pipe(prompt, seed=0, num_inference_steps=4, cfg_scale=1)
-image.save("image.jpg")
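
Editor's note: this deleted example depended on the adapter for 4-step, cfg_scale=1 sampling, which is why it is removed here. For the fp8 support this branch adds, usage would presumably toggle the enable_fp8_attention parameter visible in the pipeline hunks above; the call below is a hedged sketch based on that parameter name only, not code taken from the repository:

from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
image = pipe("a serene underwater portrait", seed=0, enable_fp8_attention=True)  # flag assumed from the diff
image.save("image_fp8.jpg")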