From e85f42b474ffa5d79a9c1fee2c22fc774d3e6478 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 8 Aug 2025 16:51:42 +0800 Subject: [PATCH] qwen-image-acc-adapter --- diffsynth/configs/model_config.py | 2 + .../models/qwen_image_accelerate_adapter.py | 63 +++++++++++++++++++ diffsynth/pipelines/qwen_image.py | 17 +++-- diffsynth/schedulers/flow_match.py | 6 +- .../others/initialize_accelerate_adapter.py | 32 ++++++++++ test_accelerate.py | 18 ++++++ 6 files changed, 131 insertions(+), 7 deletions(-) create mode 100644 diffsynth/models/qwen_image_accelerate_adapter.py create mode 100644 examples/qwen_image/model_training/full/others/initialize_accelerate_adapter.py create mode 100644 test_accelerate.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index e328593..96b3142 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -72,6 +72,7 @@ from ..models.flux_lora_encoder import FluxLoRAEncoder from ..models.nexus_gen_projector import NexusGenAdapter, NexusGenImageEmbeddingMerger from ..models.nexus_gen import NexusGenAutoregressiveModel +from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter from ..models.qwen_image_dit import QwenImageDiT from ..models.qwen_image_text_encoder import QwenImageTextEncoder from ..models.qwen_image_vae import QwenImageVAE @@ -165,6 +166,7 @@ model_loader_configs = [ (None, "63c969fd37cce769a90aa781fbff5f81", ["flux_dit", "nexus_gen_editing_adapter"], [FluxDiT, NexusGenImageEmbeddingMerger], "civitai"), (None, "2bd19e845116e4f875a0a048e27fc219", ["nexus_gen_llm"], [NexusGenAutoregressiveModel], "civitai"), (None, "0319a1cb19835fb510907dd3367c95ff", ["qwen_image_dit"], [QwenImageDiT], "civitai"), + (None, "ae9d13bfc578702baf6445d2cf3d1d46", ["qwen_image_accelerate_adapter"], [QwenImageAccelerateAdapter], "civitai"), (None, "8004730443f55db63092006dd9f7110e", ["qwen_image_text_encoder"], [QwenImageTextEncoder], "diffusers"), (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"), ] diff --git a/diffsynth/models/qwen_image_accelerate_adapter.py b/diffsynth/models/qwen_image_accelerate_adapter.py new file mode 100644 index 0000000..810f3eb --- /dev/null +++ b/diffsynth/models/qwen_image_accelerate_adapter.py @@ -0,0 +1,63 @@ +from .qwen_image_dit import QwenImageTransformerBlock, AdaLayerNorm, TimestepEmbeddings +from einops import rearrange +import torch + + + +class QwenImageAccelerateAdapter(torch.nn.Module): + def __init__( + self, + num_layers: int = 1, + ): + super().__init__() + self.proj_latents_in = torch.nn.Linear(64, 3072) + self.time_text_embed = TimestepEmbeddings(256, 3072, diffusers_compatible_format=True, scale=1000, align_dtype_to_timestep=True) + self.transformer_blocks = torch.nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=3072, + num_attention_heads=24, + attention_head_dim=128, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = AdaLayerNorm(3072, single=True) + self.proj_out = torch.nn.Linear(3072, 64) + self.proj_latents_out = torch.nn.Linear(64, 64) + + def forward( + self, + latents=None, + image=None, + text=None, + image_rotary_emb=None, + timestep=None, + ): + latents = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + image = image + self.proj_latents_in(latents) + conditioning = self.time_text_embed(timestep, image.dtype) + for block in self.transformer_blocks: + text, image = block( + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + ) + image = self.norm_out(image, conditioning) + image = self.proj_out(image) + image = image + self.proj_latents_out(latents) + return image + + @staticmethod + def state_dict_converter(): + return QwenImageAccelerateAdapterStateDictConverter() + + + +class QwenImageAccelerateAdapterStateDictConverter(): + def __init__(self): + pass + + def from_civitai(self, state_dict): + return state_dict diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 3e952c0..74426cd 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -6,6 +6,7 @@ from tqdm import tqdm from einops import rearrange from ..models import ModelManager, load_state_dict +from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter from ..models.qwen_image_dit import QwenImageDiT from ..models.qwen_image_text_encoder import QwenImageTextEncoder from ..models.qwen_image_vae import QwenImageVAE @@ -27,12 +28,13 @@ class QwenImagePipeline(BasePipeline): from transformers import Qwen2Tokenizer self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02) + self.accelerate_adapter: QwenImageAccelerateAdapter = None self.text_encoder: QwenImageTextEncoder = None self.dit: QwenImageDiT = None self.vae: QwenImageVAE = None self.tokenizer: Qwen2Tokenizer = None self.unit_runner = PipelineUnitRunner() - self.in_iteration_models = ("dit",) + self.in_iteration_models = ("accelerate_adapter", "dit",) self.units = [ QwenImageUnit_ShapeChecker(), QwenImageUnit_NoiseInitializer(), @@ -187,6 +189,7 @@ class QwenImagePipeline(BasePipeline): pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder") pipe.dit = model_manager.fetch_model("qwen_image_dit") pipe.vae = model_manager.fetch_model("qwen_image_vae") + pipe.accelerate_adapter = model_manager.fetch_model("qwen_image_accelerate_adapter") if tokenizer_config is not None and pipe.text_encoder is not None: tokenizer_config.download_if_necessary() from transformers import Qwen2Tokenizer @@ -226,7 +229,7 @@ class QwenImagePipeline(BasePipeline): progress_bar_cmd = tqdm, ): # Scheduler - self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16)) + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), random_sigmas=True) # Parameters inputs_posi = { @@ -433,6 +436,7 @@ class QwenImageUnit_EntityControl(PipelineUnit): def model_fn_qwen_image( dit: QwenImageDiT = None, + accelerate_adapter: QwenImageAccelerateAdapter = None, latents=None, timestep=None, prompt_emb=None, @@ -478,9 +482,12 @@ def model_fn_qwen_image( attention_mask=attention_mask, enable_fp8_attention=enable_fp8_attention, ) - - image = dit.norm_out(image, conditioning) - image = dit.proj_out(image) + + if accelerate_adapter is not None: + image = accelerate_adapter(latents, image, text, image_rotary_emb, timestep) + else: + image = dit.norm_out(image, conditioning) + image = dit.proj_out(image) latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2) return latents diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py index 6a8e235..2cde08f 100644 --- a/diffsynth/schedulers/flow_match.py +++ b/diffsynth/schedulers/flow_match.py @@ -31,11 +31,13 @@ class FlowMatchScheduler(): self.set_timesteps(num_inference_steps) - def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None): + def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, random_sigmas=False): if shift is not None: self.shift = shift sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength - if self.extra_one_step: + if random_sigmas: + self.sigmas = torch.Tensor(sorted([torch.rand((1,)).item() * (sigma_start - self.sigma_min) for i in range(num_inference_steps - 1)] + [sigma_start], reverse=True)) + elif self.extra_one_step: self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1] else: self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps) diff --git a/examples/qwen_image/model_training/full/others/initialize_accelerate_adapter.py b/examples/qwen_image/model_training/full/others/initialize_accelerate_adapter.py new file mode 100644 index 0000000..4a032da --- /dev/null +++ b/examples/qwen_image/model_training/full/others/initialize_accelerate_adapter.py @@ -0,0 +1,32 @@ +# This script is for initializing a Qwen-Image-Accelerate-Adapter +from diffsynth import load_state_dict, hash_state_dict_keys +from diffsynth.pipelines.qwen_image import QwenImageAccelerateAdapter +import torch +from safetensors.torch import save_file + + +state_dict_dit = {} +for i in range(1, 10): + state_dict_dit.update(load_state_dict(f"models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-0000{i}-of-00009.safetensors", torch_dtype=torch.bfloat16, device="cuda")) + +adapter = QwenImageAccelerateAdapter().to(dtype=torch.bfloat16, device="cuda") +state_dict_adapter = adapter.state_dict() + +state_dict_init = {} +for k in state_dict_adapter: + if k.startswith("transformer_blocks"): + name = k.replace("transformer_blocks.0.", "transformer_blocks.59.") + param = state_dict_dit[name] + if "_mod." in k: + param[2*3072: 3*3072] = 0 + param[5*3072: 6*3072] = 0 + state_dict_init[k] = param + elif k in state_dict_dit: + state_dict_init[k] = state_dict_dit[k] + else: + state_dict_init[k] = torch.zeros_like(state_dict_adapter[k]) + print("Zero initialized:", k) +adapter.load_state_dict(state_dict_init) + +print(hash_state_dict_keys(state_dict_init)) +save_file(state_dict_init, "models/adapter.safetensors") \ No newline at end of file diff --git a/test_accelerate.py b/test_accelerate.py new file mode 100644 index 0000000..c6fc80a --- /dev/null +++ b/test_accelerate.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig("models/adapter.safetensors") + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=4, cfg_scale=1) +image.save("image.jpg")