This commit is contained in:
Artiprocher
2026-01-08 13:21:33 +08:00
parent dd479e5bff
commit 8ba528a8f6
4 changed files with 22 additions and 11 deletions

View File

@@ -601,7 +601,9 @@ def model_fn_z_image_turbo(
if control_context is not None: if control_context is not None:
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy) kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner( refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1) dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1,
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
for layer_id, layer in enumerate(dit.noise_refiner): for layer_id, layer in enumerate(dit.noise_refiner):
x = gradient_checkpoint_forward( x = gradient_checkpoint_forward(
@@ -640,7 +642,9 @@ def model_fn_z_image_turbo(
if control_context is not None: if control_context is not None:
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy) kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
hints = controlnet.forward_layers( hints = controlnet.forward_layers(
unified, cap_feats, control_context, control_context_item_seqlens, kwargs) unified, cap_feats, control_context, control_context_item_seqlens, kwargs,
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
for layer_id, layer in enumerate(dit.layers): for layer_id, layer in enumerate(dit.layers):
unified = gradient_checkpoint_forward( unified = gradient_checkpoint_forward(

View File

@@ -108,7 +108,14 @@ def test_flux():
run_inference("examples/flux/model_training/validate_lora") run_inference("examples/flux/model_training/validate_lora")
def test_z_image():
    """Run the Z-Image example checks in a fixed order.

    The order is: the two inference examples, then full training followed
    by its validation, then LoRA training followed by its validation.
    Each step hands one example path to a runner helper defined elsewhere
    in this file; what a runner does with the path is not visible here.
    """
    # (runner, example path) pairs; order matters because each validation
    # step presumably consumes the output of the training step before it.
    steps = (
        (run_inference, "examples/z_image/model_inference"),
        (run_inference, "examples/z_image/model_inference_low_vram"),
        (run_train_multi_GPU, "examples/z_image/model_training/full"),
        (run_inference, "examples/z_image/model_training/validate_full"),
        (run_train_single_GPU, "examples/z_image/model_training/lora"),
        (run_inference, "examples/z_image/model_training/validate_lora"),
    )
    for runner, example_path in steps:
        runner(example_path)
if __name__ == "__main__": if __name__ == "__main__":
test_qwen_image() test_z_image()
test_flux()
test_wan()

View File

@@ -37,11 +37,11 @@ pipe = ZImagePipeline.from_pretrained(
# Load images # Load images
snapshot_download( snapshot_download(
model_id="DiffSynth-Studio/Qwen-Image-i2L", model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L",
allow_file_pattern="assets/style/*", allow_file_pattern="assets/style/*",
local_dir="data/examples" local_dir="data/style_input"
) )
images = [Image.open(f"data/style/1/{i}.jpg") for i in range(5)] images = [Image.open(f"data/style_input/assets/style/1/{i}.jpg") for i in range(6)]
# Image to LoRA # Image to LoRA
with torch.no_grad(): with torch.no_grad():

View File

@@ -37,11 +37,11 @@ pipe = ZImagePipeline.from_pretrained(
# Load images # Load images
snapshot_download( snapshot_download(
model_id="DiffSynth-Studio/Qwen-Image-i2L", model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L",
allow_file_pattern="assets/style/*", allow_file_pattern="assets/style/*",
local_dir="data/examples" local_dir="data/style_input"
) )
images = [Image.open(f"data/style/1/{i}.jpg") for i in range(5)] images = [Image.open(f"data/style_input/assets/style/1/{i}.jpg") for i in range(6)]
# Image to LoRA # Image to LoRA
with torch.no_grad(): with torch.no_grad():