flex full support

This commit is contained in:
Artiprocher
2025-04-25 11:32:13 +08:00
parent ccf24c363f
commit 419ace37f3
2 changed files with 27 additions and 11 deletions

View File

@@ -362,7 +362,7 @@ class FluxImagePipeline(BasePipeline):
return {}, controlnet_image
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, tiled=False, tile_size=64, tile_stride=32):
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32):
if self.dit.input_dim == 196:
if flex_inpaint_image is None:
flex_inpaint_image = torch.zeros_like(latents)
@@ -375,13 +375,16 @@ class FluxImagePipeline(BasePipeline):
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
if flex_control_image is None:
flex_control_image = torch.zeros_like(latents)
else:
flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
flex_kwargs = {"flex_condition": flex_condition}
flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))]
flex_kwargs = {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
else:
flex_kwargs = {}
return flex_kwargs
@@ -427,6 +430,8 @@ class FluxImagePipeline(BasePipeline):
flex_inpaint_image=None,
flex_inpaint_mask=None,
flex_control_image=None,
flex_control_strength=0.5,
flex_control_stop=0.5,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
@@ -635,6 +640,8 @@ def lets_dance_flux(
id_emb=None,
infinityou_guidance=None,
flex_condition=None,
flex_uncondition=None,
flex_control_stop_timestep=None,
tea_cache: TeaCache = None,
**kwargs
):
@@ -686,8 +693,12 @@ def lets_dance_flux(
controlnet_frames, **controlnet_extra_kwargs
)
# Flex
if flex_condition is not None:
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
if timestep.tolist()[0] >= flex_control_stop_timestep:
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
else:
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
if image_ids is None:
image_ids = dit.prepare_image_ids(hidden_states)

View File

@@ -23,22 +23,27 @@ image = pipe(
image.save("image_1.jpg")
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[0:300, 300:800] = 255
mask[200:400, 400:700] = 255
mask = Image.fromarray(mask)
mask.save("image_mask.jpg")
inpaint_image = np.array(image)
inpaint_image[0:300, 300:800] = 0
inpaint_image = Image.fromarray(inpaint_image)
inpaint_image.save("image_inpaint.jpg")
inpaint_image = image
image = pipe(
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
num_inference_steps=50, embedded_guidance=3.5,
flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
seed=4
)
image.save("image_2.jpg")
control_image = Annotator("canny")(image)
control_image.save("image_control.jpg")
image = pipe(
prompt="portrait of a beautiful Asian girl, long hair, yellow t-shirt, sunshine, beach",
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
num_inference_steps=50, embedded_guidance=3.5,
flex_control_image=control_image,
seed=4
)
image.save("image_2.jpg")
image.save("image_3.jpg")