mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support klein edit
This commit is contained in:
@@ -37,6 +37,7 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
Flux2Unit_Qwen3PromptEmbedder(),
|
||||
Flux2Unit_NoiseInitializer(),
|
||||
Flux2Unit_InputImageEmbedder(),
|
||||
Flux2Unit_EditImageEmbedder(),
|
||||
Flux2Unit_ImageIDs(),
|
||||
]
|
||||
self.model_fn = model_fn_flux2
|
||||
@@ -79,6 +80,9 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Edit
|
||||
edit_image: Union[Image.Image, List[Image.Image]] = None,
|
||||
edit_image_auto_resize: bool = True,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
@@ -102,6 +106,7 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
@@ -456,6 +461,64 @@ class Flux2Unit_InputImageEmbedder(PipelineUnit):
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
class Flux2Unit_EditImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image", "edit_image_auto_resize"),
|
||||
output_params=("edit_latents", "edit_image_ids"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def calculate_dimensions(self, target_area, ratio):
|
||||
import math
|
||||
width = math.sqrt(target_area * ratio)
|
||||
height = width / ratio
|
||||
width = round(width / 32) * 32
|
||||
height = round(height / 32) * 32
|
||||
return width, height
|
||||
|
||||
def edit_image_auto_resize(self, edit_image):
|
||||
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
|
||||
return edit_image.resize((calculated_width, calculated_height))
|
||||
|
||||
def process_image_ids(self, image_latents, scale=10):
|
||||
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
||||
t_coords = [t.view(-1) for t in t_coords]
|
||||
|
||||
image_latent_ids = []
|
||||
for x, t in zip(image_latents, t_coords):
|
||||
x = x.squeeze(0)
|
||||
_, height, width = x.shape
|
||||
|
||||
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
||||
image_latent_ids.append(x_ids)
|
||||
|
||||
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
||||
image_latent_ids = image_latent_ids.unsqueeze(0)
|
||||
|
||||
return image_latent_ids
|
||||
|
||||
def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if isinstance(edit_image, Image.Image):
|
||||
edit_image = [edit_image]
|
||||
resized_edit_image, edit_latents = [], []
|
||||
for image in edit_image:
|
||||
# Preprocess
|
||||
if edit_image_auto_resize is None or edit_image_auto_resize:
|
||||
image = self.edit_image_auto_resize(image)
|
||||
resized_edit_image.append(image)
|
||||
# Encode
|
||||
image = pipe.preprocess_image(image)
|
||||
latents = pipe.vae.encode(image)
|
||||
edit_latents.append(latents)
|
||||
edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device)
|
||||
edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1)
|
||||
return {"edit_latents": edit_latents, "edit_image_ids": edit_image_ids}
|
||||
|
||||
|
||||
class Flux2Unit_ImageIDs(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -490,10 +553,17 @@ def model_fn_flux2(
|
||||
prompt_embeds=None,
|
||||
text_ids=None,
|
||||
image_ids=None,
|
||||
edit_latents=None,
|
||||
edit_image_ids=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
image_seq_len = latents.shape[1]
|
||||
if edit_latents is not None:
|
||||
image_seq_len = latents.shape[1]
|
||||
latents = torch.concat([latents, edit_latents], dim=1)
|
||||
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
|
||||
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
|
||||
model_output = dit(
|
||||
hidden_states=latents,
|
||||
@@ -505,4 +575,5 @@ def model_fn_flux2(
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
model_output = model_output[:, :image_seq_len]
|
||||
return model_output
|
||||
|
||||
Reference in New Issue
Block a user