Diffusion Templates framework

This commit is contained in:
Artiprocher
2026-04-08 15:25:33 +08:00
parent f88b99cb4f
commit 9f8c352a15
10 changed files with 526 additions and 241 deletions

View File

@@ -40,6 +40,7 @@ class Flux2ImagePipeline(BasePipeline):
Flux2Unit_InputImageEmbedder(),
Flux2Unit_EditImageEmbedder(),
Flux2Unit_ImageIDs(),
Flux2Unit_Inpaint(),
]
self.model_fn = model_fn_flux2
@@ -94,8 +95,15 @@ class Flux2ImagePipeline(BasePipeline):
# Steps
num_inference_steps: int = 30,
# KV Cache
skill_cache = None,
negative_skill_cache = None,
kv_cache = None,
negative_kv_cache = None,
# LoRA
lora = None,
negative_lora = None,
# Inpaint
inpaint_mask: Image.Image = None,
inpaint_blur_size: int = None,
inpaint_blur_sigma: float = None,
# Progress bar
progress_bar_cmd = tqdm,
):
@@ -104,11 +112,11 @@ class Flux2ImagePipeline(BasePipeline):
# Parameters
inputs_posi = {
"prompt": prompt,
"skill_cache": skill_cache,
"kv_cache": kv_cache,
}
inputs_nega = {
"negative_prompt": negative_prompt,
"skill_cache": negative_skill_cache,
"kv_cache": negative_kv_cache,
}
inputs_shared = {
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
@@ -117,6 +125,9 @@ class Flux2ImagePipeline(BasePipeline):
"height": height, "width": width,
"seed": seed, "rand_device": rand_device, "initial_noise": initial_noise,
"num_inference_steps": num_inference_steps,
"positive_only_lora": lora,
"negative_only_lora": negative_lora,
"inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -565,6 +576,26 @@ class Flux2Unit_ImageIDs(PipelineUnit):
return {"image_ids": image_ids}
class Flux2Unit_Inpaint(PipelineUnit):
    """Pipeline unit that turns a user-supplied inpainting mask image into a
    flattened per-token mask tensor consumed by the Flux2 model function.

    Registered after the image-ID unit in ``Flux2ImagePipeline``; it reads the
    ``inpaint_mask`` / ``inpaint_blur_*`` call parameters and rewrites
    ``inpaint_mask`` in the shared inputs as a tensor.
    """

    def __init__(self):
        super().__init__(
            input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"),
            output_params=("inpaint_mask",),
        )

    def process(self, pipe: Flux2ImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):
        """Downsample the mask to the latent token grid, optionally blur it,
        and flatten it to shape (batch, tokens, 1).

        Args:
            pipe: owning pipeline; supplies ``preprocess_image``.
            inpaint_mask: PIL image marking the region to repaint, or None
                to disable inpainting (the unit then contributes nothing).
            height, width: target image size in pixels.
            inpaint_blur_size: half-width of the Gaussian kernel; blur is
                applied only when BOTH blur parameters are given.
            inpaint_blur_sigma: Gaussian sigma for the optional blur.

        Returns:
            ``{}`` when no mask was provided, otherwise
            ``{"inpaint_mask": tensor}`` with the flattened mask.
        """
        if inpaint_mask is None:
            return {}
        # One token covers a 16x16 pixel patch, so the mask is resized to the
        # token grid before being normalized into [0, 1].
        token_grid = (width // 16, height // 16)
        mask = pipe.preprocess_image(inpaint_mask.convert("RGB").resize(token_grid), min_value=0, max_value=1)
        # Collapse RGB to a single channel.
        mask = mask.mean(dim=1, keepdim=True)
        if inpaint_blur_size is not None and inpaint_blur_sigma is not None:
            from torchvision.transforms import GaussianBlur
            # Soften the mask boundary; kernel size must be odd, hence 2*size + 1.
            blur = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma)
            mask = blur(mask)
        # Flatten the spatial grid so the mask aligns with the token sequence.
        return {"inpaint_mask": rearrange(mask, "B C H W -> B (H W) C")}
def model_fn_flux2(
dit: Flux2DiT,
latents=None,
@@ -575,7 +606,7 @@ def model_fn_flux2(
image_ids=None,
edit_latents=None,
edit_image_ids=None,
skill_cache=None,
kv_cache=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
@@ -593,7 +624,7 @@ def model_fn_flux2(
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=image_ids,
kv_cache=skill_cache,
kv_cache=kv_cache,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)