add inpaint mask in qwen-image

This commit is contained in:
Artiprocher
2025-08-18 15:16:38 +08:00
parent ac931856d5
commit 7ed09bb78d
4 changed files with 47 additions and 8 deletions

View File

@@ -66,6 +66,7 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_ShapeChecker(), QwenImageUnit_ShapeChecker(),
QwenImageUnit_NoiseInitializer(), QwenImageUnit_NoiseInitializer(),
QwenImageUnit_InputImageEmbedder(), QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_Inpaint(),
QwenImageUnit_PromptEmbedder(), QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(), QwenImageUnit_EntityControl(),
QwenImageUnit_BlockwiseControlNet(), QwenImageUnit_BlockwiseControlNet(),
@@ -252,6 +253,10 @@ class QwenImagePipeline(BasePipeline):
# Image # Image
input_image: Image.Image = None, input_image: Image.Image = None,
denoising_strength: float = 1.0, denoising_strength: float = 1.0,
# Inpaint
inpaint_mask: Image.Image = None,
inpaint_blur_size: int = None,
inpaint_blur_sigma: float = None,
# Shape # Shape
height: int = 1328, height: int = 1328,
width: int = 1328, width: int = 1328,
@@ -288,6 +293,7 @@ class QwenImagePipeline(BasePipeline):
inputs_shared = { inputs_shared = {
"cfg_scale": cfg_scale, "cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength, "input_image": input_image, "denoising_strength": denoising_strength,
"inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma,
"height": height, "width": width, "height": height, "width": width,
"seed": seed, "rand_device": rand_device, "seed": seed, "rand_device": rand_device,
"enable_fp8_attention": enable_fp8_attention, "enable_fp8_attention": enable_fp8_attention,
@@ -314,7 +320,7 @@ class QwenImagePipeline(BasePipeline):
noise_pred = noise_pred_posi noise_pred = noise_pred_posi
# Scheduler # Scheduler
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"]) inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode # Decode
self.load_models_to_device(['vae']) self.load_models_to_device(['vae'])
@@ -363,7 +369,26 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
return {"latents": noise, "input_latents": input_latents} return {"latents": noise, "input_latents": input_latents}
else: else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": None} return {"latents": latents, "input_latents": input_latents}
class QwenImageUnit_Inpaint(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"),
)
def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):
if inpaint_mask is None:
return {}
inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 8, height // 8)), min_value=0, max_value=1)
inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True)
if inpaint_blur_size is not None and inpaint_blur_sigma is not None:
from torchvision.transforms import GaussianBlur
blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma)
inpaint_mask = blur(inpaint_mask)
return {"inpaint_mask": inpaint_mask}

View File

@@ -139,6 +139,20 @@ class BasePipeline(torch.nn.Module):
else: else:
model.eval() model.eval()
model.requires_grad_(False) model.requires_grad_(False)
def blend_with_mask(self, base, addition, mask):
return base * (1 - mask) + addition * mask
def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
timestep = scheduler.timesteps[progress_id]
if inpaint_mask is not None:
noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
latents_next = scheduler.step(noise_pred, timestep, latents)
return latents_next
@dataclass @dataclass

View File

@@ -22,12 +22,12 @@ dataset_snapshot_download(
allow_file_pattern="inpaint/*.jpg" allow_file_pattern="inpaint/*.jpg"
) )
prompt = "a cat with sunglasses" prompt = "a cat with sunglasses"
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328))
image = pipe( image = pipe(
prompt, seed=0, prompt, seed=0,
input_image=controlnet_image, inpaint_mask=inpaint_mask,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)], blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
height=1024, width=1024,
num_inference_steps=40, num_inference_steps=40,
) )
image.save("image.jpg") image.save("image.jpg")

View File

@@ -23,12 +23,12 @@ dataset_snapshot_download(
allow_file_pattern="inpaint/*.jpg" allow_file_pattern="inpaint/*.jpg"
) )
prompt = "a cat with sunglasses" prompt = "a cat with sunglasses"
controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024)) controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024)) inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328))
image = pipe( image = pipe(
prompt, seed=0, prompt, seed=0,
input_image=controlnet_image, inpaint_mask=inpaint_mask,
blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)], blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
height=1024, width=1024,
num_inference_steps=40, num_inference_steps=40,
) )
image.save("image.jpg") image.save("image.jpg")