From ccf24c363fcac755ffb24ede784af2f437b9c9cf Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 24 Apr 2025 19:18:54 +0800 Subject: [PATCH] flex control --- diffsynth/pipelines/flux_image.py | 14 +++++--- .../image_synthesis/flex_text_to_image.py | 34 +++++++++++++++---- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 647ff2d..5fa807d 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -362,20 +362,24 @@ class FluxImagePipeline(BasePipeline): return {}, controlnet_image - def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None): + def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, tiled=False, tile_size=64, tile_stride=32): if self.dit.input_dim == 196: if flex_inpaint_image is None: flex_inpaint_image = torch.zeros_like(latents) else: - pass # TODO + flex_inpaint_image = self.preprocess_image(flex_inpaint_image).to(device=self.device, dtype=self.torch_dtype) + flex_inpaint_image = self.encode_image(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) if flex_inpaint_mask is None: flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :] else: - pass # TODO + flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2])) + flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype) + flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2 if flex_control_image is None: flex_control_image = torch.zeros_like(latents) else: - pass # TODO + flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype) + flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1) flex_kwargs = {"flex_condition": flex_condition} else: @@ -463,7 +467,7 @@ class FluxImagePipeline(BasePipeline): controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) # Flex - flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image) + flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, **tiler_kwargs) # TeaCache tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None} diff --git a/examples/image_synthesis/flex_text_to_image.py b/examples/image_synthesis/flex_text_to_image.py index 288d4a6..5a56923 100644 --- a/examples/image_synthesis/flex_text_to_image.py +++ b/examples/image_synthesis/flex_text_to_image.py @@ -1,5 +1,8 @@ import torch from diffsynth import ModelManager, FluxImagePipeline, download_models +from diffsynth.controlnets.processors import Annotator +import numpy as np +from PIL import Image download_models(["FLUX.1-dev"]) @@ -12,11 +15,30 @@ model_manager.load_models([ ]) pipe = FluxImagePipeline.from_model_manager(model_manager) -prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." - -torch.manual_seed(9) image = pipe( - prompt=prompt, - num_inference_steps=50, embedded_guidance=3.5 + prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + seed=0 ) -image.save("image_1024.jpg") +image.save("image_1.jpg") + +mask = np.zeros((1024, 1024, 3), dtype=np.uint8) +mask[0:300, 300:800] = 255 +mask = Image.fromarray(mask) +mask.save("image_mask.jpg") + +inpaint_image = np.array(image) +inpaint_image[0:300, 300:800] = 0 +inpaint_image = Image.fromarray(inpaint_image) +inpaint_image.save("image_inpaint.jpg") + +control_image = Annotator("canny")(image) +control_image.save("image_control.jpg") + +image = pipe( + prompt="portrait of a beautiful Asian girl, long hair, yellow t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + flex_control_image=control_image, + seed=4 +) +image.save("image_2.jpg")