From 8ba528a8f65a252d5cfff04d526494f94ef03e55 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Thu, 8 Jan 2026 13:21:33 +0800
Subject: [PATCH] bugfix

---
 diffsynth/pipelines/z_image.py                   |  8 ++++++--
 examples/dev_tools/unit_test.py                  | 13 ++++++++++---
 .../model_inference/Z-Image-Omni-Base-i2L.py     |  6 +++---
 .../Z-Image-Omni-Base-i2L.py                     |  6 +++---
 4 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py
index df6d0aa..9ba182a 100644
--- a/diffsynth/pipelines/z_image.py
+++ b/diffsynth/pipelines/z_image.py
@@ -601,7 +601,9 @@ def model_fn_z_image_turbo(
     if control_context is not None:
         kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
         refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
-            dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1)
+            dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1,
+            use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+        )
 
     for layer_id, layer in enumerate(dit.noise_refiner):
         x = gradient_checkpoint_forward(
@@ -640,7 +642,9 @@
     if control_context is not None:
         kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
         hints = controlnet.forward_layers(
-            unified, cap_feats, control_context, control_context_item_seqlens, kwargs)
+            unified, cap_feats, control_context, control_context_item_seqlens, kwargs,
+            use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
+        )
 
     for layer_id, layer in enumerate(dit.layers):
         unified = gradient_checkpoint_forward(
diff --git a/examples/dev_tools/unit_test.py b/examples/dev_tools/unit_test.py
index 364af47..200ced8 100644
--- a/examples/dev_tools/unit_test.py
+++ b/examples/dev_tools/unit_test.py
@@ -108,7 +108,14 @@ def test_flux():
     run_inference("examples/flux/model_training/validate_lora")
 
 
+def test_z_image():
+    run_inference("examples/z_image/model_inference")
+    run_inference("examples/z_image/model_inference_low_vram")
+    run_train_multi_GPU("examples/z_image/model_training/full")
+    run_inference("examples/z_image/model_training/validate_full")
+    run_train_single_GPU("examples/z_image/model_training/lora")
+    run_inference("examples/z_image/model_training/validate_lora")
+
+
 if __name__ == "__main__":
-    test_qwen_image()
-    test_flux()
-    test_wan()
+    test_z_image()
diff --git a/examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py b/examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py
index 73e67d9..10d37ad 100644
--- a/examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py
+++ b/examples/z_image/model_inference/Z-Image-Omni-Base-i2L.py
@@ -37,11 +37,11 @@ pipe = ZImagePipeline.from_pretrained(
 
 # Load images
 snapshot_download(
-    model_id="DiffSynth-Studio/Qwen-Image-i2L",
+    model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L",
     allow_file_pattern="assets/style/*",
-    local_dir="data/examples"
+    local_dir="data/style_input"
 )
-images = [Image.open(f"data/style/1/{i}.jpg") for i in range(5)]
+images = [Image.open(f"data/style_input/assets/style/1/{i}.jpg") for i in range(6)]
 
 # Image to LoRA
 with torch.no_grad():
diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py
index 62a7b31..7378ada 100644
--- a/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py
+++ b/examples/z_image/model_inference_low_vram/Z-Image-Omni-Base-i2L.py
@@ -37,11 +37,11 @@ pipe = ZImagePipeline.from_pretrained(
 
 # Load images
 snapshot_download(
-    model_id="DiffSynth-Studio/Qwen-Image-i2L",
+    model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L",
     allow_file_pattern="assets/style/*",
-    local_dir="data/examples"
+    local_dir="data/style_input"
 )
-images = [Image.open(f"data/style/1/{i}.jpg") for i in range(5)]
+images = [Image.open(f"data/style_input/assets/style/1/{i}.jpg") for i in range(6)]
 
 # Image to LoRA
 with torch.no_grad():
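
Note: the first two hunks pass the use_gradient_checkpointing and use_gradient_checkpointing_offload flags through to controlnet.forward_refiner and controlnet.forward_layers, so the ControlNet branch follows the same activation-recomputation settings as the surrounding dit.noise_refiner and dit.layers loops. Below is a minimal sketch of the gating pattern such flags typically select; the helper name checkpoint_forward and the CPU-offload detail are illustrative assumptions, not DiffSynth-Studio's actual gradient_checkpoint_forward implementation.

# Sketch only: assumed helper illustrating how the two flags are usually consumed.
import torch
import torch.utils.checkpoint
from torch.autograd.graph import save_on_cpu


def checkpoint_forward(block, *args, use_gradient_checkpointing=False,
                       use_gradient_checkpointing_offload=False):
    # Plain forward when checkpointing is disabled or no gradients are required.
    if not (use_gradient_checkpointing and torch.is_grad_enabled()):
        return block(*args)
    if use_gradient_checkpointing_offload:
        # Assumed offload variant: park the tensors saved for backward on CPU.
        with save_on_cpu(pin_memory=True):
            return torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
    # Standard checkpointing: recompute the block's activations during backward.
    return torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)

On the test side, examples/dev_tools/unit_test.py gains a test_z_image() entry point and its __main__ block now calls only that function, so running `python examples/dev_tools/unit_test.py` exercises the Z-Image inference, low-VRAM inference, full/LoRA training, and validation examples touched by this patch.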