diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index cb676d7..94d98cc 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -623,11 +623,11 @@ def model_fn_qwen_image( timestep = timestep / 1000 image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2) + image_seq_len = image.shape[1] if edit_latents is not None: img_shapes += [(edit_latents.shape[0], edit_latents.shape[2]//2, edit_latents.shape[3]//2)] edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=edit_latents.shape[2]//2, W=edit_latents.shape[3]//2, P=2, Q=2) - image_seq_len = image.shape[1] image = torch.cat([image, edit_image], dim=1) image = dit.img_in(image) @@ -660,8 +660,8 @@ def model_fn_qwen_image( enable_fp8_attention=enable_fp8_attention, ) if blockwise_controlnet_conditioning is not None: - image = image + blockwise_controlnet.blockwise_forward( - image=image, conditionings=blockwise_controlnet_conditioning, + image[:, :image_seq_len] = image[:, :image_seq_len] + blockwise_controlnet.blockwise_forward( + image=image[:, :image_seq_len], conditionings=blockwise_controlnet_conditioning, controlnet_inputs=blockwise_controlnet_inputs, block_id=block_id, progress_id=progress_id, num_inference_steps=num_inference_steps, ) diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh b/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh index bf9fd89..ec25765 100644 --- a/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh +++ b/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh @@ -1,6 +1,8 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \ --dataset_base_path data/example_image_dataset \ --dataset_metadata_path data/example_image_dataset/metadata_edit.csv \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ --max_pixels 1048576 \ --dataset_repeat 50 \ --model_id_with_origin_paths "Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh index d12363f..0662b1e 100644 --- a/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh @@ -1,6 +1,8 @@ accelerate launch examples/qwen_image/model_training/train.py \ --dataset_base_path data/example_image_dataset \ --dataset_metadata_path data/example_image_dataset/metadata_edit.csv \ + --data_file_keys "image,edit_image" \ + --extra_inputs "edit_image" \ --max_pixels 1048576 \ --dataset_repeat 50 \ --model_id_with_origin_paths "Qwen/Qwen-Image-Edit:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \