This commit is contained in:
Artiprocher
2025-06-27 18:38:40 +08:00
parent fcf2fbc07f
commit 009f26bb40
5 changed files with 144 additions and 5 deletions

View File

@@ -102,6 +102,7 @@ class FluxImagePipeline(BasePipeline):
FluxImageUnit_InputImageEmbedder(),
FluxImageUnit_ImageIDs(),
FluxImageUnit_EmbeddedGuidanceEmbedder(),
FluxImageUnit_Kontext(),
FluxImageUnit_InfiniteYou(),
FluxImageUnit_ControlNet(),
FluxImageUnit_IPAdapter(),
@@ -211,6 +212,8 @@ class FluxImagePipeline(BasePipeline):
multidiffusion_prompts=(),
multidiffusion_masks=(),
multidiffusion_scales=(),
# Kontext
kontext_images: Union[list[Image.Image], Image.Image] = None,
# ControlNet
controlnet_inputs: list[ControlNetInput] = None,
# IP-Adapter
@@ -257,6 +260,7 @@ class FluxImagePipeline(BasePipeline):
"seed": seed, "rand_device": rand_device,
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
"kontext_images": kontext_images,
"controlnet_inputs": controlnet_inputs,
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
@@ -378,6 +382,32 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
class FluxImageUnit_Kontext(PipelineUnit):
def __init__(self):
super().__init__(input_params=("kontext_images", "tiled", "tile_size", "tile_stride"))
def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride):
if kontext_images is None:
return {}
if not isinstance(kontext_images, list):
kontext_images = [kontext_images]
kontext_latents = []
kontext_image_ids = []
for kontext_image in kontext_images:
kontext_image = pipe.preprocess_image(kontext_image)
kontext_latent = pipe.vae_encoder(kontext_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image_ids = pipe.dit.prepare_image_ids(kontext_latent)
image_ids[..., 0] = 1
kontext_image_ids.append(image_ids)
kontext_latent = pipe.dit.patchify(kontext_latent)
kontext_latents.append(kontext_latent)
kontext_latents = torch.concat(kontext_latents, dim=1)
kontext_image_ids = torch.concat(kontext_image_ids, dim=-2)
return {"kontext_latents": kontext_latents, "kontext_image_ids": kontext_image_ids}
class FluxImageUnit_ControlNet(PipelineUnit):
def __init__(self):
super().__init__(
@@ -688,6 +718,8 @@ def model_fn_flux_image(
guidance=None,
text_ids=None,
image_ids=None,
kontext_latents=None,
kontext_image_ids=None,
controlnet_inputs=None,
controlnet_conditionings=None,
tiled=False,
@@ -787,6 +819,11 @@ def model_fn_flux_image(
height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states)
# Kontext
if kontext_latents is not None:
image_ids = torch.concat([image_ids, kontext_image_ids], dim=-2)
hidden_states = torch.concat([hidden_states, kontext_latents], dim=1)
# Step1x
if step1x_reference_latents is not None:
step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
@@ -827,7 +864,10 @@ def model_fn_flux_image(
)
# ControlNet
if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
if kontext_latents is None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
else:
hidden_states[:, :-kontext_latents.shape[1]] = hidden_states[:, :-kontext_latents.shape[1]] + controlnet_res_stack[block_id]
# Single Blocks
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
@@ -846,7 +886,10 @@ def model_fn_flux_image(
)
# ControlNet
if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
if kontext_latents is None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
else:
hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] = hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] + controlnet_single_res_stack[block_id]
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
if tea_cache is not None:
@@ -858,6 +901,10 @@ def model_fn_flux_image(
# Step1x
if step1x_reference_latents is not None:
hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
# Kontext
if kontext_latents is not None:
hidden_states = hidden_states[:, :-kontext_latents.shape[1]]
hidden_states = dit.unpatchify(hidden_states, height, width)