mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
kontext
This commit is contained in:
@@ -102,6 +102,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
FluxImageUnit_InputImageEmbedder(),
|
||||
FluxImageUnit_ImageIDs(),
|
||||
FluxImageUnit_EmbeddedGuidanceEmbedder(),
|
||||
FluxImageUnit_Kontext(),
|
||||
FluxImageUnit_InfiniteYou(),
|
||||
FluxImageUnit_ControlNet(),
|
||||
FluxImageUnit_IPAdapter(),
|
||||
@@ -211,6 +212,8 @@ class FluxImagePipeline(BasePipeline):
|
||||
multidiffusion_prompts=(),
|
||||
multidiffusion_masks=(),
|
||||
multidiffusion_scales=(),
|
||||
# Kontext
|
||||
kontext_images: Union[list[Image.Image], Image.Image] = None,
|
||||
# ControlNet
|
||||
controlnet_inputs: list[ControlNetInput] = None,
|
||||
# IP-Adapter
|
||||
@@ -257,6 +260,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
|
||||
"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
|
||||
"kontext_images": kontext_images,
|
||||
"controlnet_inputs": controlnet_inputs,
|
||||
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
|
||||
@@ -378,6 +382,32 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
|
||||
|
||||
|
||||
|
||||
class FluxImageUnit_Kontext(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("kontext_images", "tiled", "tile_size", "tile_stride"))
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride):
|
||||
if kontext_images is None:
|
||||
return {}
|
||||
if not isinstance(kontext_images, list):
|
||||
kontext_images = [kontext_images]
|
||||
|
||||
kontext_latents = []
|
||||
kontext_image_ids = []
|
||||
for kontext_image in kontext_images:
|
||||
kontext_image = pipe.preprocess_image(kontext_image)
|
||||
kontext_latent = pipe.vae_encoder(kontext_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image_ids = pipe.dit.prepare_image_ids(kontext_latent)
|
||||
image_ids[..., 0] = 1
|
||||
kontext_image_ids.append(image_ids)
|
||||
kontext_latent = pipe.dit.patchify(kontext_latent)
|
||||
kontext_latents.append(kontext_latent)
|
||||
kontext_latents = torch.concat(kontext_latents, dim=1)
|
||||
kontext_image_ids = torch.concat(kontext_image_ids, dim=-2)
|
||||
return {"kontext_latents": kontext_latents, "kontext_image_ids": kontext_image_ids}
|
||||
|
||||
|
||||
|
||||
class FluxImageUnit_ControlNet(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -688,6 +718,8 @@ def model_fn_flux_image(
|
||||
guidance=None,
|
||||
text_ids=None,
|
||||
image_ids=None,
|
||||
kontext_latents=None,
|
||||
kontext_image_ids=None,
|
||||
controlnet_inputs=None,
|
||||
controlnet_conditionings=None,
|
||||
tiled=False,
|
||||
@@ -787,6 +819,11 @@ def model_fn_flux_image(
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = dit.patchify(hidden_states)
|
||||
|
||||
# Kontext
|
||||
if kontext_latents is not None:
|
||||
image_ids = torch.concat([image_ids, kontext_image_ids], dim=-2)
|
||||
hidden_states = torch.concat([hidden_states, kontext_latents], dim=1)
|
||||
|
||||
# Step1x
|
||||
if step1x_reference_latents is not None:
|
||||
step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
|
||||
@@ -827,7 +864,10 @@ def model_fn_flux_image(
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:
|
||||
hidden_states = hidden_states + controlnet_res_stack[block_id]
|
||||
if kontext_latents is None:
|
||||
hidden_states = hidden_states + controlnet_res_stack[block_id]
|
||||
else:
|
||||
hidden_states[:, :-kontext_latents.shape[1]] = hidden_states[:, :-kontext_latents.shape[1]] + controlnet_res_stack[block_id]
|
||||
|
||||
# Single Blocks
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
@@ -846,7 +886,10 @@ def model_fn_flux_image(
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:
|
||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||
if kontext_latents is None:
|
||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||
else:
|
||||
hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] = hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] + controlnet_single_res_stack[block_id]
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
if tea_cache is not None:
|
||||
@@ -858,6 +901,10 @@ def model_fn_flux_image(
|
||||
# Step1x
|
||||
if step1x_reference_latents is not None:
|
||||
hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
|
||||
|
||||
# Kontext
|
||||
if kontext_latents is not None:
|
||||
hidden_states = hidden_states[:, :-kontext_latents.shape[1]]
|
||||
|
||||
hidden_states = dit.unpatchify(hidden_states, height, width)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user