qwen-image-acc-adapter

This commit is contained in:
Artiprocher
2025-08-08 16:51:42 +08:00
parent 32cf5d32ce
commit e85f42b474
6 changed files with 131 additions and 7 deletions

View File

@@ -6,6 +6,7 @@ from tqdm import tqdm
from einops import rearrange
from ..models import ModelManager, load_state_dict
from ..models.qwen_image_accelerate_adapter import QwenImageAccelerateAdapter
from ..models.qwen_image_dit import QwenImageDiT
from ..models.qwen_image_text_encoder import QwenImageTextEncoder
from ..models.qwen_image_vae import QwenImageVAE
@@ -27,12 +28,13 @@ class QwenImagePipeline(BasePipeline):
from transformers import Qwen2Tokenizer
self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
self.accelerate_adapter: QwenImageAccelerateAdapter = None
self.text_encoder: QwenImageTextEncoder = None
self.dit: QwenImageDiT = None
self.vae: QwenImageVAE = None
self.tokenizer: Qwen2Tokenizer = None
self.unit_runner = PipelineUnitRunner()
self.in_iteration_models = ("dit",)
self.in_iteration_models = ("accelerate_adapter", "dit",)
self.units = [
QwenImageUnit_ShapeChecker(),
QwenImageUnit_NoiseInitializer(),
@@ -187,6 +189,7 @@ class QwenImagePipeline(BasePipeline):
pipe.text_encoder = model_manager.fetch_model("qwen_image_text_encoder")
pipe.dit = model_manager.fetch_model("qwen_image_dit")
pipe.vae = model_manager.fetch_model("qwen_image_vae")
pipe.accelerate_adapter = model_manager.fetch_model("qwen_image_accelerate_adapter")
if tokenizer_config is not None and pipe.text_encoder is not None:
tokenizer_config.download_if_necessary()
from transformers import Qwen2Tokenizer
@@ -226,7 +229,7 @@ class QwenImagePipeline(BasePipeline):
progress_bar_cmd = tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), random_sigmas=True)
# Parameters
inputs_posi = {
@@ -433,6 +436,7 @@ class QwenImageUnit_EntityControl(PipelineUnit):
def model_fn_qwen_image(
dit: QwenImageDiT = None,
accelerate_adapter: QwenImageAccelerateAdapter = None,
latents=None,
timestep=None,
prompt_emb=None,
@@ -478,9 +482,12 @@ def model_fn_qwen_image(
attention_mask=attention_mask,
enable_fp8_attention=enable_fp8_attention,
)
image = dit.norm_out(image, conditioning)
image = dit.proj_out(image)
if accelerate_adapter is not None:
image = accelerate_adapter(latents, image, text, image_rotary_emb, timestep)
else:
image = dit.norm_out(image, conditioning)
image = dit.proj_out(image)
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
return latents