diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
index d9da07c..b17b412 100644
--- a/diffsynth/pipelines/qwen_image.py
+++ b/diffsynth/pipelines/qwen_image.py
@@ -278,8 +278,9 @@ class QwenImagePipeline(BasePipeline):
         eligen_entity_prompts: list[str] = None,
         eligen_entity_masks: list[Image.Image] = None,
         eligen_enable_on_negative: bool = False,
-        # Edit Image
+        # Qwen-Image-Edit
         edit_image: Image.Image = None,
+        edit_image_auto_resize: bool = True,
         edit_rope_interpolation: bool = False,
         # FP8
         enable_fp8_attention: bool = False,
@@ -311,7 +312,7 @@ class QwenImagePipeline(BasePipeline):
             "blockwise_controlnet_inputs": blockwise_controlnet_inputs,
             "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
             "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
-            "edit_image": edit_image, "edit_rope_interpolation": edit_rope_interpolation,
+            "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
         }
         for unit in self.units:
             inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -584,17 +585,33 @@ class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
 
 
 class QwenImageUnit_EditImageEmbedder(PipelineUnit):
     def __init__(self):
         super().__init__(
-            input_params=("edit_image", "tiled", "tile_size", "tile_stride"),
+            input_params=("edit_image", "tiled", "tile_size", "tile_stride", "edit_image_auto_resize"),
             onload_model_names=("vae",)
         )
-    def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride):
+
+    def calculate_dimensions(self, target_area, ratio):
+        import math
+        width = math.sqrt(target_area * ratio)
+        height = width / ratio
+        width = round(width / 32) * 32
+        height = round(height / 32) * 32
+        return width, height
+
+
+    def edit_image_auto_resize(self, edit_image):
+        calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
+        return edit_image.resize((calculated_width, calculated_height))
+
+
+    def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride, edit_image_auto_resize=False):
         if edit_image is None: return {}
+        resized_edit_image = self.edit_image_auto_resize(edit_image) if edit_image_auto_resize else edit_image
         pipe.load_models_to_device(['vae'])
-        edit_image = pipe.preprocess_image(edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
+        edit_image = pipe.preprocess_image(resized_edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
         edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
-        return {"edit_latents": edit_latents}
+        return {"edit_latents": edit_latents, "edit_image": resized_edit_image}
 
 
 def model_fn_qwen_image(
diff --git a/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py b/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py
index df373d8..bbe3b7e 100644
--- a/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py
+++ b/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py
@@ -22,5 +22,5 @@ image.save("image.jpg")
 
 prompt = "将裙子变成粉色"
 image = image.resize((512, 384))
-image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True)
+image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True, edit_image_auto_resize=False)
 image.save(f"image2.jpg")
diff --git a/examples/qwen_image/model_inference/Qwen-Image-Edit.py b/examples/qwen_image/model_inference/Qwen-Image-Edit.py
index e67d744..38a412b 100644
--- a/examples/qwen_image/model_inference/Qwen-Image-Edit.py
+++ b/examples/qwen_image/model_inference/Qwen-Image-Edit.py
@@ -13,9 +13,14 @@ pipe = QwenImagePipeline.from_pretrained(
     processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
 )
 prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
-image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1024, width=1024)
-image.save("image1.jpg")
+input_image = pipe(prompt=prompt, seed=0, num_inference_steps=40, height=1328, width=1024)
+input_image.save("image1.jpg")
 
 prompt = "将裙子改为粉色"
-image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=1024)
+# edit_image_auto_resize=True: auto resize input image to match the area of 1024*1024 with the original aspect ratio
+# edit_image_auto_resize=False: do not resize input image
+image = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=True)
 image.save(f"image2.jpg")
+
+image = pipe(prompt, edit_image=input_image, seed=1, num_inference_steps=40, height=1328, width=1024, edit_image_auto_resize=False)
+image.save(f"image3.jpg")
diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py
index 496f80e..55c771f 100644
--- a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py
+++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py
@@ -24,5 +24,5 @@ image.save("image.jpg")
 
 prompt = "将裙子变成粉色"
 image = image.resize((512, 384))
-image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True)
+image = pipe(prompt, edit_image=image, seed=1, num_inference_steps=40, height=1024, width=768, edit_rope_interpolation=True, edit_image_auto_resize=False)
 image.save(f"image2.jpg")
diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py
index 31bbfda..7418661 100644
--- a/examples/qwen_image/model_training/train.py
+++ b/examples/qwen_image/model_training/train.py
@@ -78,6 +78,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
             "rand_device": self.pipe.device,
             "use_gradient_checkpointing": self.use_gradient_checkpointing,
             "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
+            "edit_image_auto_resize": True,
         }
 
         # Extra inputs