add inpaint mask in qwen-image

This commit adds inpainting support to the Qwen-Image pipeline: a new QwenImageUnit_Inpaint preprocessing unit, inpaint_mask / inpaint_blur_size / inpaint_blur_sigma parameters on the pipeline call, a mask-aware step() helper on BasePipeline, and updated example scripts.
@@ -66,6 +66,7 @@ class QwenImagePipeline(BasePipeline):
             QwenImageUnit_ShapeChecker(),
             QwenImageUnit_NoiseInitializer(),
             QwenImageUnit_InputImageEmbedder(),
+            QwenImageUnit_Inpaint(),
             QwenImageUnit_PromptEmbedder(),
             QwenImageUnit_EntityControl(),
             QwenImageUnit_BlockwiseControlNet(),
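QwenImageUnit_Inpaint is registered after QwenImageUnit_InputImageEmbedder and before prompt embedding, so the mask is prepared once during preprocessing and then carried through inputs_shared (see the third hunk below) to every denoising step.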
@@ -252,6 +253,10 @@ class QwenImagePipeline(BasePipeline):
         # Image
         input_image: Image.Image = None,
         denoising_strength: float = 1.0,
+        # Inpaint
+        inpaint_mask: Image.Image = None,
+        inpaint_blur_size: int = None,
+        inpaint_blur_sigma: float = None,
         # Shape
         height: int = 1328,
         width: int = 1328,
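For reference, a minimal call sketch using the new parameters. This is hedged: `pipe` is assumed to be an already-loaded QwenImagePipeline, and the file names and blur values are illustrative, not from this commit.

# Minimal usage sketch for the new inpainting parameters (illustrative values;
# assumes `pipe` is an already-loaded QwenImagePipeline and the files exist).
from PIL import Image

input_image = Image.open("photo.jpg").convert("RGB").resize((1328, 1328))
inpaint_mask = Image.open("mask.jpg").convert("RGB").resize((1328, 1328))

image = pipe(
    prompt="a cat with sunglasses",
    input_image=input_image,     # image to edit
    inpaint_mask=inpaint_mask,   # white = regenerate, black = keep
    inpaint_blur_size=10,        # feathering: kernel = 2 * size + 1 = 21
    inpaint_blur_sigma=5.0,      # Gaussian sigma for the feathering
    seed=0,
)
image.save("inpainted.jpg")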
@@ -288,6 +293,7 @@ class QwenImagePipeline(BasePipeline):
         inputs_shared = {
             "cfg_scale": cfg_scale,
             "input_image": input_image, "denoising_strength": denoising_strength,
+            "inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma,
             "height": height, "width": width,
             "seed": seed, "rand_device": rand_device,
             "enable_fp8_attention": enable_fp8_attention,
@@ -314,7 +320,7 @@ class QwenImagePipeline(BasePipeline):
                 noise_pred = noise_pred_posi

            # Scheduler
-            inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
+            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

        # Decode
        self.load_models_to_device(['vae'])
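The per-step scheduler call now routes through BasePipeline.step (added in a later hunk), which consumes the inpaint_mask and input_latents carried in inputs_shared via its keyword arguments; a toy sketch of that blending follows that hunk.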
@@ -363,7 +369,26 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
             return {"latents": noise, "input_latents": input_latents}
         else:
             latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
-            return {"latents": latents, "input_latents": None}
+            return {"latents": latents, "input_latents": input_latents}



+class QwenImageUnit_Inpaint(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"),
+        )
+
+    def process(self, pipe: QwenImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):
+        if inpaint_mask is None:
+            return {}
+        inpaint_mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize((width // 8, height // 8)), min_value=0, max_value=1)
+        inpaint_mask = inpaint_mask.mean(dim=1, keepdim=True)
+        if inpaint_blur_size is not None and inpaint_blur_sigma is not None:
+            from torchvision.transforms import GaussianBlur
+            blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma)
+            inpaint_mask = blur(inpaint_mask)
+        return {"inpaint_mask": inpaint_mask}
+
+
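To make the mask preprocessing concrete, here is a self-contained sketch of what the unit computes, outside the pipeline. preprocess_mask is a hypothetical stand-in: the 1/255 scaling approximates pipe.preprocess_image(..., min_value=0, max_value=1), while the resize, channel mean, and GaussianBlur mirror the unit above.

# Hypothetical stand-in for QwenImageUnit_Inpaint's preprocessing; the 1/255
# scaling approximates pipe.preprocess_image(..., min_value=0, max_value=1).
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import GaussianBlur

def preprocess_mask(mask, height, width, blur_size=None, blur_sigma=None):
    # Downscale to the latent grid (the VAE has a spatial stride of 8).
    mask = mask.convert("RGB").resize((width // 8, height // 8))
    tensor = torch.from_numpy(np.array(mask)).float() / 255.0  # (H/8, W/8, 3) in [0, 1]
    tensor = tensor.permute(2, 0, 1).unsqueeze(0)              # (1, 3, H/8, W/8)
    tensor = tensor.mean(dim=1, keepdim=True)                  # collapse RGB to 1 channel
    if blur_size is not None and blur_sigma is not None:
        # GaussianBlur needs an odd kernel size, hence 2 * blur_size + 1.
        tensor = GaussianBlur(kernel_size=blur_size * 2 + 1, sigma=blur_sigma)(tensor)
    return tensor

# Example: a 1328x1328 mask becomes a (1, 1, 166, 166) tensor.
mask = preprocess_mask(Image.open("mask.jpg"), height=1328, width=1328,
                       blur_size=10, blur_sigma=5.0)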
@@ -139,6 +139,20 @@ class BasePipeline(torch.nn.Module):
         else:
             model.eval()
             model.requires_grad_(False)


+    def blend_with_mask(self, base, addition, mask):
+        return base * (1 - mask) + addition * mask
+
+
+    def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
+        timestep = scheduler.timesteps[progress_id]
+        if inpaint_mask is not None:
+            noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
+            noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
+        latents_next = scheduler.step(noise_pred, timestep, latents)
+        return latents_next
+
+
+
 @dataclass
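This step wrapper is the core of the inpainting logic: where the mask is 0, the model's prediction is replaced by the prediction that would steer the latents back onto the noised input image, so unmasked regions stay pinned to the original content. Below is a toy illustration with a stand-in scheduler; the linear update and the return_to_timestep formula are simplified assumptions, not the library's exact implementation.

# Toy illustration of mask-blended stepping. ToyScheduler is a hypothetical
# stand-in; its linear update and return_to_timestep are simplified assumptions.
import torch

class ToyScheduler:
    def __init__(self, num_steps=10):
        # Timesteps decreasing from 1.0 toward 0.1 (flow-matching style).
        self.timesteps = torch.linspace(1.0, 0.1, num_steps)

    def step(self, noise_pred, timestep, latents):
        # Move latents a constant step along the predicted velocity.
        return latents - 0.1 * noise_pred

    def return_to_timestep(self, timestep, latents, input_latents):
        # Velocity that points the current latents back toward the input latents.
        return (latents - input_latents) / timestep

def blend_with_mask(base, addition, mask):
    # Keep `base` where mask == 0, take `addition` where mask == 1.
    return base * (1 - mask) + addition * mask

def step(scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None):
    timestep = scheduler.timesteps[progress_id]
    if inpaint_mask is not None:
        expected = scheduler.return_to_timestep(timestep, latents, input_latents)
        noise_pred = blend_with_mask(expected, noise_pred, inpaint_mask)
    return scheduler.step(noise_pred, timestep, latents)

# Unmasked pixels (mask == 0) are pulled toward input_latents each step;
# masked pixels (mask == 1) follow the model's prediction.
scheduler = ToyScheduler()
latents = torch.randn(1, 4, 8, 8)
input_latents = torch.zeros(1, 4, 8, 8)
mask = torch.zeros(1, 1, 8, 8)
mask[..., 2:6, 2:6] = 1.0
latents = step(scheduler, latents, 0, torch.randn_like(latents), input_latents, mask)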
@@ -22,12 +22,12 @@ dataset_snapshot_download(
     allow_file_pattern="inpaint/*.jpg"
 )
 prompt = "a cat with sunglasses"
-controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
-inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
+controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328))
+inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328))
 image = pipe(
     prompt, seed=0,
+    input_image=controlnet_image, inpaint_mask=inpaint_mask,
     blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
-    height=1024, width=1024,
     num_inference_steps=40,
 )
 image.save("image.jpg")
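Both example scripts (this hunk and the nearly identical one below, in a second example file) now resize the input image and mask to 1328x1328 and drop the explicit height=1024, width=1024 arguments, so generation falls back to the pipeline's 1328x1328 defaults and the image, mask, and latent shapes stay consistent. They also pass input_image and inpaint_mask to the top-level pipe() call, so the new inpainting path runs alongside the blockwise ControlNet.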
@@ -23,12 +23,12 @@ dataset_snapshot_download(
     allow_file_pattern="inpaint/*.jpg"
 )
 prompt = "a cat with sunglasses"
-controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
-inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
+controlnet_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1328, 1328))
+inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1328, 1328))
 image = pipe(
     prompt, seed=0,
+    input_image=controlnet_image, inpaint_mask=inpaint_mask,
     blockwise_controlnet_inputs=[ControlNetInput(image=controlnet_image, inpaint_mask=inpaint_mask)],
-    height=1024, width=1024,
     num_inference_steps=40,
 )
 image.save("image.jpg")