support klein edit

This commit is contained in:
Artiprocher
2026-01-20 12:58:18 +08:00
parent 88497b5c13
commit a835df984c
18 changed files with 273 additions and 2 deletions

View File

@@ -37,6 +37,7 @@ class Flux2ImagePipeline(BasePipeline):
Flux2Unit_Qwen3PromptEmbedder(),
Flux2Unit_NoiseInitializer(),
Flux2Unit_InputImageEmbedder(),
Flux2Unit_EditImageEmbedder(),
Flux2Unit_ImageIDs(),
]
self.model_fn = model_fn_flux2
@@ -79,6 +80,9 @@ class Flux2ImagePipeline(BasePipeline):
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Edit
edit_image: Union[Image.Image, List[Image.Image]] = None,
edit_image_auto_resize: bool = True,
# Shape
height: int = 1024,
width: int = 1024,
@@ -102,6 +106,7 @@ class Flux2ImagePipeline(BasePipeline):
inputs_shared = {
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
"input_image": input_image, "denoising_strength": denoising_strength,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
@@ -456,6 +461,64 @@ class Flux2Unit_InputImageEmbedder(PipelineUnit):
return {"latents": latents, "input_latents": input_latents}
class Flux2Unit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "edit_image_auto_resize"),
output_params=("edit_latents", "edit_image_ids"),
onload_model_names=("vae",)
)
def calculate_dimensions(self, target_area, ratio):
import math
width = math.sqrt(target_area * ratio)
height = width / ratio
width = round(width / 32) * 32
height = round(height / 32) * 32
return width, height
def edit_image_auto_resize(self, edit_image):
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
return edit_image.resize((calculated_width, calculated_height))
def process_image_ids(self, image_latents, scale=10):
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
t_coords = [t.view(-1) for t in t_coords]
image_latent_ids = []
for x, t in zip(image_latents, t_coords):
x = x.squeeze(0)
_, height, width = x.shape
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
image_latent_ids.append(x_ids)
image_latent_ids = torch.cat(image_latent_ids, dim=0)
image_latent_ids = image_latent_ids.unsqueeze(0)
return image_latent_ids
def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if isinstance(edit_image, Image.Image):
edit_image = [edit_image]
resized_edit_image, edit_latents = [], []
for image in edit_image:
# Preprocess
if edit_image_auto_resize is None or edit_image_auto_resize:
image = self.edit_image_auto_resize(image)
resized_edit_image.append(image)
# Encode
image = pipe.preprocess_image(image)
latents = pipe.vae.encode(image)
edit_latents.append(latents)
edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device)
edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1)
return {"edit_latents": edit_latents, "edit_image_ids": edit_image_ids}
class Flux2Unit_ImageIDs(PipelineUnit):
def __init__(self):
super().__init__(
@@ -490,10 +553,17 @@ def model_fn_flux2(
prompt_embeds=None,
text_ids=None,
image_ids=None,
edit_latents=None,
edit_image_ids=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
image_seq_len = latents.shape[1]
if edit_latents is not None:
image_seq_len = latents.shape[1]
latents = torch.concat([latents, edit_latents], dim=1)
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
model_output = dit(
hidden_states=latents,
@@ -505,4 +575,5 @@ def model_fn_flux2(
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
model_output = model_output[:, :image_seq_len]
return model_output