flex full support

This commit is contained in:
Artiprocher
2025-04-25 11:32:13 +08:00
parent ccf24c363f
commit 419ace37f3
2 changed files with 27 additions and 11 deletions

View File

@@ -362,7 +362,7 @@ class FluxImagePipeline(BasePipeline):
return {}, controlnet_image
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, tiled=False, tile_size=64, tile_stride=32):
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32):
if self.dit.input_dim == 196:
if flex_inpaint_image is None:
flex_inpaint_image = torch.zeros_like(latents)
@@ -375,13 +375,16 @@ class FluxImagePipeline(BasePipeline):
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
if flex_control_image is None:
flex_control_image = torch.zeros_like(latents)
else:
flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
flex_kwargs = {"flex_condition": flex_condition}
flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))]
flex_kwargs = {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
else:
flex_kwargs = {}
return flex_kwargs
@@ -427,6 +430,8 @@ class FluxImagePipeline(BasePipeline):
flex_inpaint_image=None,
flex_inpaint_mask=None,
flex_control_image=None,
flex_control_strength=0.5,
flex_control_stop=0.5,
# TeaCache
tea_cache_l1_thresh=None,
# Tile
@@ -635,6 +640,8 @@ def lets_dance_flux(
id_emb=None,
infinityou_guidance=None,
flex_condition=None,
flex_uncondition=None,
flex_control_stop_timestep=None,
tea_cache: TeaCache = None,
**kwargs
):
@@ -686,8 +693,12 @@ def lets_dance_flux(
controlnet_frames, **controlnet_extra_kwargs
)
# Flex
if flex_condition is not None:
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
if timestep.tolist()[0] >= flex_control_stop_timestep:
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
else:
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
if image_ids is None:
image_ids = dit.prepare_image_ids(hidden_states)