flex control

This commit is contained in:
Artiprocher
2025-04-24 19:18:54 +08:00
parent b7a1ac6671
commit ccf24c363f
2 changed files with 37 additions and 11 deletions

View File

@@ -362,20 +362,24 @@ class FluxImagePipeline(BasePipeline):
return {}, controlnet_image
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None):
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, tiled=False, tile_size=64, tile_stride=32):
if self.dit.input_dim == 196:
if flex_inpaint_image is None:
flex_inpaint_image = torch.zeros_like(latents)
else:
pass # TODO
flex_inpaint_image = self.preprocess_image(flex_inpaint_image).to(device=self.device, dtype=self.torch_dtype)
flex_inpaint_image = self.encode_image(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if flex_inpaint_mask is None:
flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
else:
pass # TODO
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
if flex_control_image is None:
flex_control_image = torch.zeros_like(latents)
else:
pass # TODO
flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
flex_kwargs = {"flex_condition": flex_condition}
else:
@@ -463,7 +467,7 @@ class FluxImagePipeline(BasePipeline):
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# Flex
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image)
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, **tiler_kwargs)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}

View File

@@ -1,5 +1,8 @@
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_models
from diffsynth.controlnets.processors import Annotator
import numpy as np
from PIL import Image
download_models(["FLUX.1-dev"])
@@ -12,11 +15,30 @@ model_manager.load_models([
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
torch.manual_seed(9)
image = pipe(
prompt=prompt,
num_inference_steps=50, embedded_guidance=3.5
prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach",
num_inference_steps=50, embedded_guidance=3.5,
seed=0
)
image.save("image_1024.jpg")
image.save("image_1.jpg")
# Build an RGB mask array of size 1024x1024 that is 255 inside the rectangle
# rows 0-299 / columns 300-799 and 0 elsewhere, then save it as an image.
# NOTE(review): assumes the image generated above is 1024x1024 -- confirm.
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[0:300, 300:800] = 255
mask = Image.fromarray(mask)
mask.save("image_mask.jpg")
# Copy the previous result into an array and zero out the same rectangle that
# the mask marks, then save the result as an image.
# NOTE(review): presumably `image` is a PIL image (it is saved with .save) -- confirm.
inpaint_image = np.array(image)
inpaint_image[0:300, 300:800] = 0
inpaint_image = Image.fromarray(inpaint_image)
inpaint_image.save("image_inpaint.jpg")
# NOTE(review): `mask` and `inpaint_image` are only written to disk in the
# lines shown here; they are not passed to the pipe call below.
# Run the "canny" annotator on the previous result to obtain the control image.
control_image = Annotator("canny")(image)
control_image.save("image_control.jpg")
# Second generation: a different prompt and seed, with the control image
# supplied through the `flex_control_image` argument.
image = pipe(
prompt="portrait of a beautiful Asian girl, long hair, yellow t-shirt, sunshine, beach",
num_inference_steps=50, embedded_guidance=3.5,
flex_control_image=control_image,
seed=4
)
image.save("image_2.jpg")