mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
flex full support
This commit is contained in:
@@ -362,7 +362,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
return {}, controlnet_image
|
||||
|
||||
|
||||
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, tiled=False, tile_size=64, tile_stride=32):
|
||||
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32):
|
||||
if self.dit.input_dim == 196:
|
||||
if flex_inpaint_image is None:
|
||||
flex_inpaint_image = torch.zeros_like(latents)
|
||||
@@ -375,13 +375,16 @@ class FluxImagePipeline(BasePipeline):
|
||||
flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
|
||||
flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
|
||||
flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2
|
||||
flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)
|
||||
if flex_control_image is None:
|
||||
flex_control_image = torch.zeros_like(latents)
|
||||
else:
|
||||
flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength
|
||||
flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
|
||||
flex_kwargs = {"flex_condition": flex_condition}
|
||||
flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
|
||||
flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))]
|
||||
flex_kwargs = {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
|
||||
else:
|
||||
flex_kwargs = {}
|
||||
return flex_kwargs
|
||||
@@ -427,6 +430,8 @@ class FluxImagePipeline(BasePipeline):
|
||||
flex_inpaint_image=None,
|
||||
flex_inpaint_mask=None,
|
||||
flex_control_image=None,
|
||||
flex_control_strength=0.5,
|
||||
flex_control_stop=0.5,
|
||||
# TeaCache
|
||||
tea_cache_l1_thresh=None,
|
||||
# Tile
|
||||
@@ -635,6 +640,8 @@ def lets_dance_flux(
|
||||
id_emb=None,
|
||||
infinityou_guidance=None,
|
||||
flex_condition=None,
|
||||
flex_uncondition=None,
|
||||
flex_control_stop_timestep=None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -686,8 +693,12 @@ def lets_dance_flux(
|
||||
controlnet_frames, **controlnet_extra_kwargs
|
||||
)
|
||||
|
||||
# Flex
|
||||
if flex_condition is not None:
|
||||
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
||||
if timestep.tolist()[0] >= flex_control_stop_timestep:
|
||||
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
|
||||
else:
|
||||
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = dit.prepare_image_ids(hidden_states)
|
||||
|
||||
@@ -23,22 +23,27 @@ image = pipe(
|
||||
image.save("image_1.jpg")
|
||||
|
||||
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
||||
mask[0:300, 300:800] = 255
|
||||
mask[200:400, 400:700] = 255
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save("image_mask.jpg")
|
||||
|
||||
inpaint_image = np.array(image)
|
||||
inpaint_image[0:300, 300:800] = 0
|
||||
inpaint_image = Image.fromarray(inpaint_image)
|
||||
inpaint_image.save("image_inpaint.jpg")
|
||||
inpaint_image = image
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
|
||||
seed=4
|
||||
)
|
||||
image.save("image_2.jpg")
|
||||
|
||||
control_image = Annotator("canny")(image)
|
||||
control_image.save("image_control.jpg")
|
||||
|
||||
image = pipe(
|
||||
prompt="portrait of a beautiful Asian girl, long hair, yellow t-shirt, sunshine, beach",
|
||||
prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
|
||||
num_inference_steps=50, embedded_guidance=3.5,
|
||||
flex_control_image=control_image,
|
||||
seed=4
|
||||
)
|
||||
image.save("image_2.jpg")
|
||||
image.save("image_3.jpg")
|
||||
|
||||
Reference in New Issue
Block a user