mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-08 17:18:21 +00:00
Diffusion Templates framework
This commit is contained in:
@@ -40,6 +40,7 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
Flux2Unit_InputImageEmbedder(),
|
||||
Flux2Unit_EditImageEmbedder(),
|
||||
Flux2Unit_ImageIDs(),
|
||||
Flux2Unit_Inpaint(),
|
||||
]
|
||||
self.model_fn = model_fn_flux2
|
||||
|
||||
@@ -94,8 +95,15 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# KV Cache
|
||||
skill_cache = None,
|
||||
negative_skill_cache = None,
|
||||
kv_cache = None,
|
||||
negative_kv_cache = None,
|
||||
# LoRA
|
||||
lora = None,
|
||||
negative_lora = None,
|
||||
# Inpaint
|
||||
inpaint_mask: Image.Image = None,
|
||||
inpaint_blur_size: int = None,
|
||||
inpaint_blur_sigma: float = None,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
@@ -104,11 +112,11 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
"skill_cache": skill_cache,
|
||||
"kv_cache": kv_cache,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
"skill_cache": negative_skill_cache,
|
||||
"kv_cache": negative_kv_cache,
|
||||
}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
|
||||
@@ -117,6 +125,9 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device, "initial_noise": initial_noise,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"positive_only_lora": lora,
|
||||
"negative_only_lora": negative_lora,
|
||||
"inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -565,6 +576,26 @@ class Flux2Unit_ImageIDs(PipelineUnit):
|
||||
return {"image_ids": image_ids}
|
||||
|
||||
|
||||
class Flux2Unit_Inpaint(PipelineUnit):
    """Pipeline unit that turns a user-supplied inpaint mask image into a per-token mask.

    Reads ``inpaint_mask``, ``height``, ``width``, ``inpaint_blur_size`` and
    ``inpaint_blur_sigma`` from the pipeline inputs and writes the processed
    mask back under ``inpaint_mask``.
    """

    def __init__(self):
        super().__init__(
            input_params=("inpaint_mask", "height", "width", "inpaint_blur_size", "inpaint_blur_sigma"),
            output_params=("inpaint_mask",),
        )

    def process(self, pipe: Flux2ImagePipeline, inpaint_mask, height, width, inpaint_blur_size, inpaint_blur_sigma):
        """Convert the mask image to a flattened single-channel tensor.

        Args:
            pipe: Pipeline whose ``preprocess_image`` converts the image to a tensor.
            inpaint_mask: Mask image (must support ``convert`` / ``resize``), or ``None``.
            height: Target image height; the mask is resized to ``height // 16`` rows.
            width: Target image width; the mask is resized to ``width // 16`` columns.
            inpaint_blur_size: Blur radius; the Gaussian kernel size is ``2 * size + 1``.
            inpaint_blur_sigma: Sigma passed to the Gaussian blur.

        Returns:
            An empty dict when no mask is given; otherwise ``{"inpaint_mask": tensor}``
            with the tensor laid out as ``B (H W) C``.
        """
        # No mask supplied: contribute nothing to the pipeline inputs.
        if inpaint_mask is None:
            return {}

        # Downscale by a factor of 16 per side.
        # NOTE(review): presumably 16 matches the latent/patch grid of the model — confirm.
        grid_size = (width // 16, height // 16)
        resized_image = inpaint_mask.convert("RGB").resize(grid_size)
        mask = pipe.preprocess_image(resized_image, min_value=0, max_value=1)

        # Average over dim 1 (the "C" axis in the rearrange pattern below), keeping the axis.
        mask = mask.mean(dim=1, keepdim=True)

        # Blurring is applied only when BOTH blur parameters are provided.
        blur_requested = not (inpaint_blur_size is None or inpaint_blur_sigma is None)
        if blur_requested:
            from torchvision.transforms import GaussianBlur

            # 2 * size + 1 is always odd.
            smoother = GaussianBlur(kernel_size=inpaint_blur_size * 2 + 1, sigma=inpaint_blur_sigma)
            mask = smoother(mask)

        # Flatten the spatial grid into a single sequence axis.
        flattened = rearrange(mask, "B C H W -> B (H W) C")
        return {"inpaint_mask": flattened}
|
||||
|
||||
|
||||
def model_fn_flux2(
|
||||
dit: Flux2DiT,
|
||||
latents=None,
|
||||
@@ -575,7 +606,7 @@ def model_fn_flux2(
|
||||
image_ids=None,
|
||||
edit_latents=None,
|
||||
edit_image_ids=None,
|
||||
skill_cache=None,
|
||||
kv_cache=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
@@ -593,7 +624,7 @@ def model_fn_flux2(
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=image_ids,
|
||||
kv_cache=skill_cache,
|
||||
kv_cache=kv_cache,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user