mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 10:48:11 +00:00
bugfix
This commit is contained in:
@@ -601,7 +601,9 @@ def model_fn_z_image_turbo(
|
|||||||
if control_context is not None:
|
if control_context is not None:
|
||||||
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
|
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
|
||||||
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
|
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
|
||||||
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1)
|
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
)
|
||||||
|
|
||||||
for layer_id, layer in enumerate(dit.noise_refiner):
|
for layer_id, layer in enumerate(dit.noise_refiner):
|
||||||
x = gradient_checkpoint_forward(
|
x = gradient_checkpoint_forward(
|
||||||
@@ -640,7 +642,9 @@ def model_fn_z_image_turbo(
|
|||||||
if control_context is not None:
|
if control_context is not None:
|
||||||
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
|
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
|
||||||
hints = controlnet.forward_layers(
|
hints = controlnet.forward_layers(
|
||||||
unified, cap_feats, control_context, control_context_item_seqlens, kwargs)
|
unified, cap_feats, control_context, control_context_item_seqlens, kwargs,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
)
|
||||||
|
|
||||||
for layer_id, layer in enumerate(dit.layers):
|
for layer_id, layer in enumerate(dit.layers):
|
||||||
unified = gradient_checkpoint_forward(
|
unified = gradient_checkpoint_forward(
|
||||||
|
|||||||
@@ -108,7 +108,14 @@ def test_flux():
|
|||||||
run_inference("examples/flux/model_training/validate_lora")
|
run_inference("examples/flux/model_training/validate_lora")
|
||||||
|
|
||||||
|
|
||||||
|
def test_z_image():
|
||||||
|
run_inference("examples/z_image/model_inference")
|
||||||
|
run_inference("examples/z_image/model_inference_low_vram")
|
||||||
|
run_train_multi_GPU("examples/z_image/model_training/full")
|
||||||
|
run_inference("examples/z_image/model_training/validate_full")
|
||||||
|
run_train_single_GPU("examples/z_image/model_training/lora")
|
||||||
|
run_inference("examples/z_image/model_training/validate_lora")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_qwen_image()
|
test_z_image()
|
||||||
test_flux()
|
|
||||||
test_wan()
|
|
||||||
|
|||||||
@@ -37,11 +37,11 @@ pipe = ZImagePipeline.from_pretrained(
|
|||||||
|
|
||||||
# Load images
|
# Load images
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
model_id="DiffSynth-Studio/Qwen-Image-i2L",
|
model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L",
|
||||||
allow_file_pattern="assets/style/*",
|
allow_file_pattern="assets/style/*",
|
||||||
local_dir="data/examples"
|
local_dir="data/style_input"
|
||||||
)
|
)
|
||||||
images = [Image.open(f"data/style/1/{i}.jpg") for i in range(5)]
|
images = [Image.open(f"data/style_input/assets/style/1/{i}.jpg") for i in range(6)]
|
||||||
|
|
||||||
# Image to LoRA
|
# Image to LoRA
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
@@ -37,11 +37,11 @@ pipe = ZImagePipeline.from_pretrained(
|
|||||||
|
|
||||||
# Load images
|
# Load images
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
model_id="DiffSynth-Studio/Qwen-Image-i2L",
|
model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L",
|
||||||
allow_file_pattern="assets/style/*",
|
allow_file_pattern="assets/style/*",
|
||||||
local_dir="data/examples"
|
local_dir="data/style_input"
|
||||||
)
|
)
|
||||||
images = [Image.open(f"data/style/1/{i}.jpg") for i in range(5)]
|
images = [Image.open(f"data/style_input/assets/style/1/{i}.jpg") for i in range(6)]
|
||||||
|
|
||||||
# Image to LoRA
|
# Image to LoRA
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
Reference in New Issue
Block a user