diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 3b529b2..4e9cf05 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -66,6 +66,7 @@ class QwenImagePipeline(BasePipeline): QwenImageUnit_ShapeChecker(), QwenImageUnit_NoiseInitializer(), QwenImageUnit_InputImageEmbedder(), + QwenImageUnit_Inpaint(), QwenImageUnit_PromptEmbedder(), QwenImageUnit_EntityControl(), QwenImageUnit_BlockwiseControlNet(), @@ -252,6 +253,10 @@ class QwenImagePipeline(BasePipeline): # Image input_image: Image.Image = None, denoising_strength: float = 1.0, + # Inpaint + inpaint_mask: Image.Image = None, + inpaint_blur_size: int = None, + inpaint_blur_sigma: float = None, # Shape height: int = 1328, width: int = 1328, @@ -288,6 +293,7 @@ class QwenImagePipeline(BasePipeline): inputs_shared = { "cfg_scale": cfg_scale, "input_image": input_image, "denoising_strength": denoising_strength, + "inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma, "height": height, "width": width, "seed": seed, "rand_device": rand_device, "enable_fp8_attention": enable_fp8_attention, @@ -314,7 +320,7 @@ class QwenImagePipeline(BasePipeline): noise_pred = noise_pred_posi # Scheduler - inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) # Decode self.load_models_to_device(['vae']) @@ -363,7 +369,26 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit): return {"latents": noise, "input_latents": input_latents} else: latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) - return {"latents": latents, "input_latents": None} + return {"latents": latents, "input_latents": input_latents} + + + +class QwenImageUnit_Inpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"), + ) + + def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma): + if inpaint_mask is None: + return {} + inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 8, height // 8)), min_value=0, max_value=1) + inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True) + if inpaint_blur_size is not None and inpaint_blur_sigma is not None: + from torchvision.transforms import GaussianBlur + blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma) + inpaint_mask = blur(inpaint_mask) + return {"inpaint_mask": inpaint_mask} diff --git a/diffsynth/utils/__init__.py b/diffsynth/utils/__init__.py index 97f3926..ec3c727 100644 --- a/diffsynth/utils/__init__.py +++ b/diffsynth/utils/__init__.py @@ -139,6 +139,20 @@ class BasePipeline(torch.nn.Module): else: model.eval() model.requires_grad_(False) + + + def blend_with_mask(self, base, addition, mask): + return base * (1 - mask) + addition * mask + + + def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs): + timestep = scheduler.timesteps[progress_id] + if inpaint_mask is not None: + noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents) + noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask) + latents_next = scheduler.step(noise_pred, timestep, latents) + return latents_next + @dataclass diff --git a/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py b/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py index b03c19d..1cb98e0 100644 --- a/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py +++ b/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py @@ -22,12 +22,12 @@ dataset_snapshot_download( allow_file_pattern="inpaint/*.jpg" ) prompt = "a cat with sunglasses" -controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) -inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328)) image = pipe( prompt, seed=0, + input_image=controlnet_image, inpaint_mask=inpaint_mask, blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)], - height=1024, width=1024, num_inference_steps=40, ) image.save("image.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py index 7d74c1e..0989932 100644 --- a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py @@ -23,12 +23,12 @@ dataset_snapshot_download( allow_file_pattern="inpaint/*.jpg" ) prompt = "a cat with sunglasses" -controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) -inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) +controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328)) +inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328)) image = pipe( prompt, seed=0, + input_image=controlnet_image, inpaint_mask=inpaint_mask, blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)], - height=1024, width=1024, num_inference_steps=40, ) image.save("image.jpg")